From fb28f61f8b966163c6d8d653a0e270a894e04723 Mon Sep 17 00:00:00 2001 From: Henry Date: Sun, 9 Jul 2023 16:11:42 +0100 Subject: [PATCH] add memory option --- .../ConversationalRetrievalQAChain.ts | 49 ++++++++++++++++--- .../nodes/memory/DynamoDb/DynamoDb.ts | 5 +- packages/ui/src/utils/genericHelper.js | 6 ++- 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts b/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts index 3b7e1413f..4872717f9 100644 --- a/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts +++ b/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts @@ -2,7 +2,7 @@ import { BaseLanguageModel } from 'langchain/base_language' import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface' import { CustomChainHandler, getBaseClasses } from '../../../src/utils' import { ConversationalRetrievalQAChain } from 'langchain/chains' -import { AIChatMessage, BaseRetriever, HumanChatMessage } from 'langchain/schema' +import { AIMessage, BaseRetriever, HumanMessage } from 'langchain/schema' import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory' import { PromptTemplate } from 'langchain/prompts' @@ -20,6 +20,20 @@ const qa_template = `Use the following pieces of context to answer the question Question: {question} Helpful Answer:` +const CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT = `Given the following conversation and a follow up question, return the conversation history excerpt that includes any relevant context to the question if it exists and rephrase the follow up question to be a standalone question. +Chat History: +{chat_history} +Follow Up Input: {question} +Your answer should follow the following format: +\`\`\` +Use the following pieces of context to answer the users question. +If you don't know the answer, just say that you don't know, don't try to make up an answer. +---------------- + +Standalone question: +\`\`\` +Your answer:` + class ConversationalRetrievalQAChain_Chains implements INode { label: string name: string @@ -49,6 +63,13 @@ class ConversationalRetrievalQAChain_Chains implements INode { name: 'vectorStoreRetriever', type: 'BaseRetriever' }, + { + label: 'Memory', + name: 'memory', + type: 'DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory', + optional: true, + description: 'If no memory connected, default BufferMemory will be used' + }, { label: 'Return Source Documents', name: 'returnSourceDocuments', @@ -99,6 +120,7 @@ class ConversationalRetrievalQAChain_Chains implements INode { const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean const chainOption = nodeData.inputs?.chainOption as string + const memory = nodeData.inputs?.memory const obj: any = { verbose: process.env.DEBUG === 'true' ? true : false, @@ -106,15 +128,25 @@ class ConversationalRetrievalQAChain_Chains implements INode { type: 'stuff', prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template) }, - memory: new BufferMemory({ + questionGeneratorChainOptions: { + template: CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT + } + } + if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments + if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption } + if (memory) { + memory.inputKey = 'question' + memory.outputKey = 'text' + memory.memoryKey = 'chat_history' + obj.memory = memory + } else { + obj.memory = new BufferMemory({ memoryKey: 'chat_history', inputKey: 'question', outputKey: 'text', returnMessages: true }) } - if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments - if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption } const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj) return chain @@ -123,6 +155,8 @@ class ConversationalRetrievalQAChain_Chains implements INode { async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const chain = nodeData.instance as ConversationalRetrievalQAChain const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean + const memory = nodeData.inputs?.memory + let model = nodeData.inputs?.model // Temporary fix: https://github.com/hwchase17/langchainjs/issues/754 @@ -131,16 +165,17 @@ class ConversationalRetrievalQAChain_Chains implements INode { const obj = { question: input } - if (chain.memory && options && options.chatHistory) { + // If external memory like Zep, Redis is being used, ignore below + if (!memory && chain.memory && options && options.chatHistory) { const chatHistory = [] const histories: IMessage[] = options.chatHistory const memory = chain.memory as BaseChatMemory for (const message of histories) { if (message.type === 'apiMessage') { - chatHistory.push(new AIChatMessage(message.message)) + chatHistory.push(new AIMessage(message.message)) } else if (message.type === 'userMessage') { - chatHistory.push(new HumanChatMessage(message.message)) + chatHistory.push(new HumanMessage(message.message)) } } memory.chatHistory = new ChatMessageHistory(chatHistory) diff --git a/packages/components/nodes/memory/DynamoDb/DynamoDb.ts b/packages/components/nodes/memory/DynamoDb/DynamoDb.ts index b13680443..49d15cb61 100644 --- a/packages/components/nodes/memory/DynamoDb/DynamoDb.ts +++ b/packages/components/nodes/memory/DynamoDb/DynamoDb.ts @@ -13,8 +13,9 @@ class DynamoDb_Memory implements INode { inputs: INodeParams[] constructor() { - this.label = 'DynamoDB Memory' - this.name = 'DynamoDbMemory' + this.label = 'DynamoDB Chat Memory' + this.name = 'DynamoDBChatMemory' + this.type = 'DynamoDBChatMemory' this.icon = 'dynamodb.svg' this.category = 'Memory' this.description = 'Stores the conversation in dynamo db table' diff --git a/packages/ui/src/utils/genericHelper.js b/packages/ui/src/utils/genericHelper.js index 42a630575..305326f7a 100644 --- a/packages/ui/src/utils/genericHelper.js +++ b/packages/ui/src/utils/genericHelper.js @@ -168,8 +168,10 @@ export const isValidConnection = (connection, reactFlowInstance) => { //sourceHandle: "llmChain_0-output-llmChain-BaseChain" //targetHandle: "mrlkAgentLLM_0-input-model-BaseLanguageModel" - const sourceTypes = sourceHandle.split('-')[sourceHandle.split('-').length - 1].split('|') - const targetTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|') + let sourceTypes = sourceHandle.split('-')[sourceHandle.split('-').length - 1].split('|') + sourceTypes = sourceTypes.map((s) => s.trim()) + let targetTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|') + targetTypes = targetTypes.map((t) => t.trim()) if (targetTypes.some((t) => sourceTypes.includes(t))) { let targetNode = reactFlowInstance.getNode(target)