From 41caf2aee75067fae735724dfed84ba8972a6efd Mon Sep 17 00:00:00 2001 From: Henry Date: Mon, 9 Oct 2023 19:56:15 +0100 Subject: [PATCH] add memory fix --- .../agents/ConversationalAgent/ConversationalAgent.ts | 8 ++++++-- .../ConversationalRetrievalAgent.ts | 6 +++++- .../agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts | 8 ++++++-- .../nodes/chains/ConversationChain/ConversationChain.ts | 8 ++++++-- .../ConversationalRetrievalQAChain.ts | 6 +++++- 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/packages/components/nodes/agents/ConversationalAgent/ConversationalAgent.ts b/packages/components/nodes/agents/ConversationalAgent/ConversationalAgent.ts index 661ef151d..00f825d44 100644 --- a/packages/components/nodes/agents/ConversationalAgent/ConversationalAgent.ts +++ b/packages/components/nodes/agents/ConversationalAgent/ConversationalAgent.ts @@ -95,8 +95,12 @@ class ConversationalAgent_Agents implements INode { const callbacks = await additionalCallbacks(nodeData, options) if (options && options.chatHistory) { - memory.chatHistory = mapChatHistory(options) - executor.memory = memory + const chatHistoryClassName = memory.chatHistory.constructor.name + // Only replace when its In-Memory + if (chatHistoryClassName && chatHistoryClassName === 'ChatMessageHistory') { + memory.chatHistory = mapChatHistory(options) + executor.memory = memory + } } const result = await executor.call({ input }, [...callbacks]) diff --git a/packages/components/nodes/agents/ConversationalRetrievalAgent/ConversationalRetrievalAgent.ts b/packages/components/nodes/agents/ConversationalRetrievalAgent/ConversationalRetrievalAgent.ts index 3d70a2d32..4a908d7fe 100644 --- a/packages/components/nodes/agents/ConversationalRetrievalAgent/ConversationalRetrievalAgent.ts +++ b/packages/components/nodes/agents/ConversationalRetrievalAgent/ConversationalRetrievalAgent.ts @@ -82,7 +82,11 @@ class ConversationalRetrievalAgent_Agents implements INode { if (executor.memory) { ;(executor.memory as any).memoryKey = 'chat_history' ;(executor.memory as any).outputKey = 'output' - ;(executor.memory as any).chatHistory = mapChatHistory(options) + const chatHistoryClassName = (executor.memory as any).chatHistory.constructor.name + // Only replace when its In-Memory + if (chatHistoryClassName && chatHistoryClassName === 'ChatMessageHistory') { + ;(executor.memory as any).chatHistory = mapChatHistory(options) + } } const loggerHandler = new ConsoleCallbackHandler(options.logger) diff --git a/packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts b/packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts index c1bd32ec5..c920c399e 100644 --- a/packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts +++ b/packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts @@ -81,8 +81,12 @@ class OpenAIFunctionAgent_Agents implements INode { const memory = nodeData.inputs?.memory as BaseChatMemory if (options && options.chatHistory) { - memory.chatHistory = mapChatHistory(options) - executor.memory = memory + const chatHistoryClassName = memory.chatHistory.constructor.name + // Only replace when its In-Memory + if (chatHistoryClassName && chatHistoryClassName === 'ChatMessageHistory') { + memory.chatHistory = mapChatHistory(options) + executor.memory = memory + } } const loggerHandler = new ConsoleCallbackHandler(options.logger) diff --git a/packages/components/nodes/chains/ConversationChain/ConversationChain.ts b/packages/components/nodes/chains/ConversationChain/ConversationChain.ts index b26603e29..1cd15c9a5 100644 --- a/packages/components/nodes/chains/ConversationChain/ConversationChain.ts +++ b/packages/components/nodes/chains/ConversationChain/ConversationChain.ts @@ -106,8 +106,12 @@ class ConversationChain_Chains implements INode { const memory = nodeData.inputs?.memory as BufferMemory if (options && options.chatHistory) { - memory.chatHistory = mapChatHistory(options) - chain.memory = memory + const chatHistoryClassName = memory.chatHistory.constructor.name + // Only replace when its In-Memory + if (chatHistoryClassName && chatHistoryClassName === 'ChatMessageHistory') { + memory.chatHistory = mapChatHistory(options) + chain.memory = memory + } } const loggerHandler = new ConsoleCallbackHandler(options.logger) diff --git a/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts b/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts index 1b4675bda..9a8c1b188 100644 --- a/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts +++ b/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts @@ -179,7 +179,11 @@ class ConversationalRetrievalQAChain_Chains implements INode { const obj = { question: input } if (options && options.chatHistory && chain.memory) { - ;(chain.memory as any).chatHistory = mapChatHistory(options) + const chatHistoryClassName = (chain.memory as any).chatHistory.constructor.name + // Only replace when its In-Memory + if (chatHistoryClassName && chatHistoryClassName === 'ChatMessageHistory') { + ;(chain.memory as any).chatHistory = mapChatHistory(options) + } } const loggerHandler = new ConsoleCallbackHandler(options.logger)