diff --git a/packages/components/nodes/chains/ConversationChain/ConversationChain.ts b/packages/components/nodes/chains/ConversationChain/ConversationChain.ts index 4e39ae6db..4da707cdc 100644 --- a/packages/components/nodes/chains/ConversationChain/ConversationChain.ts +++ b/packages/components/nodes/chains/ConversationChain/ConversationChain.ts @@ -2,7 +2,7 @@ import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../.. import { ConversationChain } from 'langchain/chains' import { CustomChainHandler, getBaseClasses } from '../../../src/utils' import { ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate } from 'langchain/prompts' -import { BufferMemory, ChatMessageHistory } from 'langchain/memory' +import { BufferMemory, ChatMessageHistory, ENTITY_MEMORY_CONVERSATION_TEMPLATE, EntityMemory } from 'langchain/memory' import { BaseChatModel } from 'langchain/chat_models/base' import { AIChatMessage, HumanChatMessage } from 'langchain/schema' @@ -51,7 +51,7 @@ class ConversationChain_Chains implements INode { async init(nodeData: INodeData): Promise { const model = nodeData.inputs?.model as BaseChatModel - const memory = nodeData.inputs?.memory as BufferMemory + const memory = nodeData.inputs?.memory as BufferMemory | EntityMemory const prompt = nodeData.inputs?.systemMessagePrompt as string const obj: any = { @@ -60,11 +60,17 @@ class ConversationChain_Chains implements INode { verbose: process.env.DEBUG === 'true' ? true : false } - const chatPrompt = ChatPromptTemplate.fromPromptMessages([ - SystemMessagePromptTemplate.fromTemplate(prompt ? `${prompt}\n${systemMessage}` : systemMessage), - new MessagesPlaceholder(memory.memoryKey ?? 'chat_history'), - HumanMessagePromptTemplate.fromTemplate('{input}') - ]) + let chatPrompt: any + if (memory instanceof EntityMemory) { + chatPrompt = ENTITY_MEMORY_CONVERSATION_TEMPLATE + console.log('use ENTITY_MEMORY_CONVERSATION_TEMPLATE') + } else { + chatPrompt = ChatPromptTemplate.fromPromptMessages([ + SystemMessagePromptTemplate.fromTemplate(prompt ? `${prompt}\n${systemMessage}` : systemMessage), + new MessagesPlaceholder(memory.memoryKey ?? 'chat_history'), + HumanMessagePromptTemplate.fromTemplate('{input}') + ]) + } obj.prompt = chatPrompt const chain = new ConversationChain(obj) @@ -73,9 +79,10 @@ class ConversationChain_Chains implements INode { async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const chain = nodeData.instance as ConversationChain - const memory = nodeData.inputs?.memory as BufferMemory + const memory = nodeData.inputs?.memory as BufferMemory | EntityMemory if (options && options.chatHistory) { + console.log(`Went into options.chatHistory`) const chatHistory = [] const histories: IMessage[] = options.chatHistory @@ -90,12 +97,22 @@ class ConversationChain_Chains implements INode { chain.memory = memory } + console.log(`chain.memory is instanceof EntityMemory: ${chain.memory instanceof EntityMemory}`) + if (options.socketIO && options.socketIOClientId) { const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) const res = await chain.call({ input }, [handler]) + if (memory instanceof EntityMemory) console.log({ + res, + memory: await memory.loadMemoryVariables({ input: "Who is Jim?" }), + }) return res?.response } else { const res = await chain.call({ input }) + if (memory instanceof EntityMemory) console.log({ + res, + memory: await memory.loadMemoryVariables({ input: "Who is Jim?" }), + }) return res?.text } } diff --git a/packages/components/nodes/memory/EntityMemory/EntityMemory.ts b/packages/components/nodes/memory/EntityMemory/EntityMemory.ts new file mode 100644 index 000000000..b7ee4c262 --- /dev/null +++ b/packages/components/nodes/memory/EntityMemory/EntityMemory.ts @@ -0,0 +1,62 @@ +import { EntityMemoryInput } from 'langchain/dist/memory/entity_memory' +import { INode, INodeData, INodeParams } from '../../../src/Interface' +import { getBaseClasses } from '../../../src/utils' +import { BaseChatMemoryInput, BufferMemory, EntityMemory } from 'langchain/memory' +import { BaseLanguageModel } from 'langchain/base_language' + +class EntityMemory_Memory implements INode { + label: string + name: string + description: string + type: string + icon: string + category: string + baseClasses: string[] + inputs: INodeParams[] + + constructor() { + this.label = 'Entity Memory' + this.name = 'entityMemory' + this.type = 'EntityMemory' + this.icon = 'memory.svg' + this.category = 'Memory' + this.description = 'Using an LLM to extracts information on entities and builds up its knowledge about that entity over time' + this.baseClasses = [this.type, ...getBaseClasses(EntityMemory)] + this.inputs = [ + { + label: 'Language Model', + name: 'model', + type: 'BaseLanguageModel' + }, + { + label: 'Chat History Key', + name: 'chatHistoryKey', + type: 'string', + default: 'history' + }, + { + label: 'Entities Key', + name: 'entitiesKey', + type: 'string', + default: 'entities' + } + ] + } + + async init(nodeData: INodeData): Promise { + const model = nodeData.inputs?.model as BaseLanguageModel + const chatHistoryKey = nodeData.inputs?.chatHistoryKey as string + const entitiesKey = nodeData.inputs?.entitiesKey as string + + const obj: EntityMemoryInput = { + llm: model, + returnMessages: true, + chatHistoryKey, + entitiesKey + } + + return new EntityMemory(obj) + } +} + +module.exports = { nodeClass: EntityMemory_Memory } diff --git a/packages/components/nodes/memory/EntityMemory/memory.svg b/packages/components/nodes/memory/EntityMemory/memory.svg new file mode 100644 index 000000000..ca8e17da1 --- /dev/null +++ b/packages/components/nodes/memory/EntityMemory/memory.svg @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file