add memory option

commit fb28f61f8b
parent 21552776d6
Author: Henry
Date:   2023-07-09 16:11:42 +01:00

3 changed files with 49 additions and 11 deletions

View File

@@ -2,7 +2,7 @@ import { BaseLanguageModel } from 'langchain/base_language'
 import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
 import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
 import { ConversationalRetrievalQAChain } from 'langchain/chains'
-import { AIChatMessage, BaseRetriever, HumanChatMessage } from 'langchain/schema'
+import { AIMessage, BaseRetriever, HumanMessage } from 'langchain/schema'
 import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory'
 import { PromptTemplate } from 'langchain/prompts'
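
The schema import swap tracks an upstream LangChain.js rename: AIChatMessage and HumanChatMessage became AIMessage and HumanMessage. A minimal sketch of the renamed classes (construction is otherwise unchanged):

    import { AIMessage, HumanMessage } from 'langchain/schema'

    // Same constructor shape as the old AIChatMessage / HumanChatMessage
    const ai = new AIMessage('answer text from the model')
    const human = new HumanMessage('question text from the user')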
@@ -20,6 +20,20 @@ const qa_template = `Use the following pieces of context to answer the question
 Question: {question}
 Helpful Answer:`
 
+const CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT = `Given the following conversation and a follow up question, return the conversation history excerpt that includes any relevant context to the question if it exists and rephrase the follow up question to be a standalone question.
+Chat History:
+{chat_history}
+Follow Up Input: {question}
+Your answer should follow the following format:
+\`\`\`
+Use the following pieces of context to answer the users question.
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+----------------
+<Relevant chat history excerpt as context here>
+Standalone question: <Rephrased question here>
+\`\`\`
+Your answer:`
+
 class ConversationalRetrievalQAChain_Chains implements INode {
     label: string
     name: string
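
ConversationalRetrievalQAChain first runs a question-generator LLM call that condenses {chat_history} and {question} into a standalone query; the custom template above additionally asks the model to carry the relevant history excerpt along with the rephrased question. A small sketch of how such a template is consumed, assuming CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT is in scope:

    import { PromptTemplate } from 'langchain/prompts'

    // fromTemplate parses the {placeholders}, so the generator prompt
    // exposes exactly the two variables the chain supplies at run time
    const prompt = PromptTemplate.fromTemplate(CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT)
    console.log(prompt.inputVariables) // ['chat_history', 'question']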
@@ -49,6 +63,13 @@ class ConversationalRetrievalQAChain_Chains implements INode {
                 name: 'vectorStoreRetriever',
                 type: 'BaseRetriever'
             },
+            {
+                label: 'Memory',
+                name: 'memory',
+                type: 'DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory',
+                optional: true,
+                description: 'If no memory connected, default BufferMemory will be used'
+            },
             {
                 label: 'Return Source Documents',
                 name: 'returnSourceDocuments',
@@ -99,6 +120,7 @@ class ConversationalRetrievalQAChain_Chains implements INode {
         const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
         const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
         const chainOption = nodeData.inputs?.chainOption as string
+        const memory = nodeData.inputs?.memory
 
         const obj: any = {
             verbose: process.env.DEBUG === 'true' ? true : false,
@@ -106,15 +128,25 @@ class ConversationalRetrievalQAChain_Chains implements INode {
                 type: 'stuff',
                 prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
             },
-            memory: new BufferMemory({
-                memoryKey: 'chat_history',
-                inputKey: 'question',
-                outputKey: 'text',
-                returnMessages: true
-            })
+            questionGeneratorChainOptions: {
+                template: CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT
+            }
         }
         if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
         if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }
+        if (memory) {
+            memory.inputKey = 'question'
+            memory.outputKey = 'text'
+            memory.memoryKey = 'chat_history'
+            obj.memory = memory
+        } else {
+            obj.memory = new BufferMemory({
+                memoryKey: 'chat_history',
+                inputKey: 'question',
+                outputKey: 'text',
+                returnMessages: true
+            })
+        }
 
         const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
         return chain
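
Net effect of init(): a connected memory node is re-keyed to the names this chain reads and writes, and the fallback stays an in-process BufferMemory. A usage sketch under this era's API, assuming model and vectorStoreRetriever are already constructed:

    import { BufferMemory } from 'langchain/memory'
    import { ConversationalRetrievalQAChain } from 'langchain/chains'

    // Mirrors the default branch above; the keys must match what the chain
    // reads ('question'), writes ('text'), and injects into prompts ('chat_history')
    const memory = new BufferMemory({
        memoryKey: 'chat_history',
        inputKey: 'question',
        outputKey: 'text',
        returnMessages: true
    })
    const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, { memory })
    const res = await chain.call({ question: 'Does the plan include dental?' })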
@@ -123,6 +155,8 @@ class ConversationalRetrievalQAChain_Chains implements INode {
     async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
         const chain = nodeData.instance as ConversationalRetrievalQAChain
         const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
+        const memory = nodeData.inputs?.memory
+
         let model = nodeData.inputs?.model
 
         // Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
@@ -131,16 +165,17 @@ class ConversationalRetrievalQAChain_Chains implements INode {
         const obj = { question: input }
 
-        if (chain.memory && options && options.chatHistory) {
+        // If external memory like Zep, Redis is being used, ignore below
+        if (!memory && chain.memory && options && options.chatHistory) {
             const chatHistory = []
             const histories: IMessage[] = options.chatHistory
             const memory = chain.memory as BaseChatMemory
 
             for (const message of histories) {
                 if (message.type === 'apiMessage') {
-                    chatHistory.push(new AIChatMessage(message.message))
+                    chatHistory.push(new AIMessage(message.message))
                 } else if (message.type === 'userMessage') {
-                    chatHistory.push(new HumanChatMessage(message.message))
+                    chatHistory.push(new HumanMessage(message.message))
                 }
             }
             memory.chatHistory = new ChatMessageHistory(chatHistory)
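
run() only rebuilds the chain's memory from the UI-supplied history when no external memory is connected; Zep, Redis, and DynamoDB persist their own history, and overwriting chatHistory here would clobber it. The mapping itself, extracted as a hypothetical helper (toChatMessageHistory does not exist in the codebase):

    import { AIMessage, HumanMessage } from 'langchain/schema'
    import { ChatMessageHistory } from 'langchain/memory'
    import { IMessage } from '../../../src/Interface'

    // apiMessage is an assistant turn, userMessage a human turn
    const toChatMessageHistory = (histories: IMessage[]): ChatMessageHistory => {
        const messages: (AIMessage | HumanMessage)[] = []
        for (const message of histories) {
            if (message.type === 'apiMessage') messages.push(new AIMessage(message.message))
            else if (message.type === 'userMessage') messages.push(new HumanMessage(message.message))
        }
        return new ChatMessageHistory(messages)
    }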

View File

@@ -13,8 +13,9 @@ class DynamoDb_Memory implements INode {
     inputs: INodeParams[]
 
     constructor() {
-        this.label = 'DynamoDB Memory'
-        this.name = 'DynamoDbMemory'
+        this.label = 'DynamoDB Chat Memory'
+        this.name = 'DynamoDBChatMemory'
+        this.type = 'DynamoDBChatMemory'
         this.icon = 'dynamodb.svg'
         this.category = 'Memory'
         this.description = 'Stores the conversation in dynamo db table'

View File

@@ -168,8 +168,10 @@ export const isValidConnection = (connection, reactFlowInstance) => {
     //sourceHandle: "llmChain_0-output-llmChain-BaseChain"
     //targetHandle: "mrlkAgentLLM_0-input-model-BaseLanguageModel"
-    const sourceTypes = sourceHandle.split('-')[sourceHandle.split('-').length - 1].split('|')
-    const targetTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|')
+    let sourceTypes = sourceHandle.split('-')[sourceHandle.split('-').length - 1].split('|')
+    sourceTypes = sourceTypes.map((s) => s.trim())
+    let targetTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|')
+    targetTypes = targetTypes.map((t) => t.trim())
 
     if (targetTypes.some((t) => sourceTypes.includes(t))) {
         let targetNode = reactFlowInstance.getNode(target)
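
The trim matters because the new Memory input's type is a pipe-delimited list with spaces ('DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory'), so splitting on '|' leaves padding that defeats the includes() match. Illustration with a hypothetical target handle:

    const targetHandle = 'chain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory'
    const raw = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|')
    // raw: ['DynamoDBChatMemory ', ' RedisBackedChatMemory ', ' ZepMemory']
    const targetTypes = raw.map((t) => t.trim())
    console.log(targetTypes.includes('DynamoDBChatMemory')) // true; with raw it would be false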