add feature/EntityMemory init

This commit is contained in:
chungyau97 2023-05-26 18:05:14 +07:00
parent 2968dccd83
commit 75d561b89b
3 changed files with 95 additions and 8 deletions

View File

@ -2,7 +2,7 @@ import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../..
import { ConversationChain } from 'langchain/chains'
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
import { ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate } from 'langchain/prompts'
import { BufferMemory, ChatMessageHistory } from 'langchain/memory'
import { BufferMemory, ChatMessageHistory, ENTITY_MEMORY_CONVERSATION_TEMPLATE, EntityMemory } from 'langchain/memory'
import { BaseChatModel } from 'langchain/chat_models/base'
import { AIChatMessage, HumanChatMessage } from 'langchain/schema'
@ -51,7 +51,7 @@ class ConversationChain_Chains implements INode {
async init(nodeData: INodeData): Promise<any> {
const model = nodeData.inputs?.model as BaseChatModel
const memory = nodeData.inputs?.memory as BufferMemory
const memory = nodeData.inputs?.memory as BufferMemory | EntityMemory
const prompt = nodeData.inputs?.systemMessagePrompt as string
const obj: any = {
@ -60,11 +60,17 @@ class ConversationChain_Chains implements INode {
verbose: process.env.DEBUG === 'true' ? true : false
}
const chatPrompt = ChatPromptTemplate.fromPromptMessages([
SystemMessagePromptTemplate.fromTemplate(prompt ? `${prompt}\n${systemMessage}` : systemMessage),
new MessagesPlaceholder(memory.memoryKey ?? 'chat_history'),
HumanMessagePromptTemplate.fromTemplate('{input}')
])
let chatPrompt: any
if (memory instanceof EntityMemory) {
chatPrompt = ENTITY_MEMORY_CONVERSATION_TEMPLATE
console.log('use ENTITY_MEMORY_CONVERSATION_TEMPLATE')
} else {
chatPrompt = ChatPromptTemplate.fromPromptMessages([
SystemMessagePromptTemplate.fromTemplate(prompt ? `${prompt}\n${systemMessage}` : systemMessage),
new MessagesPlaceholder(memory.memoryKey ?? 'chat_history'),
HumanMessagePromptTemplate.fromTemplate('{input}')
])
}
obj.prompt = chatPrompt
const chain = new ConversationChain(obj)
@ -73,9 +79,10 @@ class ConversationChain_Chains implements INode {
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
const chain = nodeData.instance as ConversationChain
const memory = nodeData.inputs?.memory as BufferMemory
const memory = nodeData.inputs?.memory as BufferMemory | EntityMemory
if (options && options.chatHistory) {
console.log(`Went into options.chatHistory`)
const chatHistory = []
const histories: IMessage[] = options.chatHistory
@ -90,12 +97,22 @@ class ConversationChain_Chains implements INode {
chain.memory = memory
}
console.log(`chain.memory is instanceof EntityMemory: ${chain.memory instanceof EntityMemory}`)
if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
const res = await chain.call({ input }, [handler])
if (memory instanceof EntityMemory) console.log({
res,
memory: await memory.loadMemoryVariables({ input: "Who is Jim?" }),
})
return res?.response
} else {
const res = await chain.call({ input })
if (memory instanceof EntityMemory) console.log({
res,
memory: await memory.loadMemoryVariables({ input: "Who is Jim?" }),
})
return res?.text
}
}

View File

@ -0,0 +1,62 @@
import { EntityMemoryInput } from 'langchain/dist/memory/entity_memory'
import { INode, INodeData, INodeParams } from '../../../src/Interface'
import { getBaseClasses } from '../../../src/utils'
import { BaseChatMemoryInput, BufferMemory, EntityMemory } from 'langchain/memory'
import { BaseLanguageModel } from 'langchain/base_language'
class EntityMemory_Memory implements INode {
label: string
name: string
description: string
type: string
icon: string
category: string
baseClasses: string[]
inputs: INodeParams[]
constructor() {
this.label = 'Entity Memory'
this.name = 'entityMemory'
this.type = 'EntityMemory'
this.icon = 'memory.svg'
this.category = 'Memory'
this.description = 'Using an LLM to extracts information on entities and builds up its knowledge about that entity over time'
this.baseClasses = [this.type, ...getBaseClasses(EntityMemory)]
this.inputs = [
{
label: 'Language Model',
name: 'model',
type: 'BaseLanguageModel'
},
{
label: 'Chat History Key',
name: 'chatHistoryKey',
type: 'string',
default: 'history'
},
{
label: 'Entities Key',
name: 'entitiesKey',
type: 'string',
default: 'entities'
}
]
}
async init(nodeData: INodeData): Promise<any> {
const model = nodeData.inputs?.model as BaseLanguageModel
const chatHistoryKey = nodeData.inputs?.chatHistoryKey as string
const entitiesKey = nodeData.inputs?.entitiesKey as string
const obj: EntityMemoryInput = {
llm: model,
returnMessages: true,
chatHistoryKey,
entitiesKey
}
return new EntityMemory(obj)
}
}
module.exports = { nodeClass: EntityMemory_Memory }

View File

@ -0,0 +1,8 @@
<svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-book" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"></path>
<path d="M3 19a9 9 0 0 1 9 0a9 9 0 0 1 9 0"></path>
<path d="M3 6a9 9 0 0 1 9 0a9 9 0 0 1 9 0"></path>
<path d="M3 6l0 13"></path>
<path d="M12 6l0 13"></path>
<path d="M21 6l0 13"></path>
</svg>

After

Width:  |  Height:  |  Size: 495 B