add feature/EntityMemory init
This commit is contained in:
parent
2968dccd83
commit
75d561b89b
|
|
@ -2,7 +2,7 @@ import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../..
|
||||||
import { ConversationChain } from 'langchain/chains'
|
import { ConversationChain } from 'langchain/chains'
|
||||||
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
|
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
|
||||||
import { ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate } from 'langchain/prompts'
|
import { ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate } from 'langchain/prompts'
|
||||||
import { BufferMemory, ChatMessageHistory } from 'langchain/memory'
|
import { BufferMemory, ChatMessageHistory, ENTITY_MEMORY_CONVERSATION_TEMPLATE, EntityMemory } from 'langchain/memory'
|
||||||
import { BaseChatModel } from 'langchain/chat_models/base'
|
import { BaseChatModel } from 'langchain/chat_models/base'
|
||||||
import { AIChatMessage, HumanChatMessage } from 'langchain/schema'
|
import { AIChatMessage, HumanChatMessage } from 'langchain/schema'
|
||||||
|
|
||||||
|
|
@ -51,7 +51,7 @@ class ConversationChain_Chains implements INode {
|
||||||
|
|
||||||
async init(nodeData: INodeData): Promise<any> {
|
async init(nodeData: INodeData): Promise<any> {
|
||||||
const model = nodeData.inputs?.model as BaseChatModel
|
const model = nodeData.inputs?.model as BaseChatModel
|
||||||
const memory = nodeData.inputs?.memory as BufferMemory
|
const memory = nodeData.inputs?.memory as BufferMemory | EntityMemory
|
||||||
const prompt = nodeData.inputs?.systemMessagePrompt as string
|
const prompt = nodeData.inputs?.systemMessagePrompt as string
|
||||||
|
|
||||||
const obj: any = {
|
const obj: any = {
|
||||||
|
|
@ -60,11 +60,17 @@ class ConversationChain_Chains implements INode {
|
||||||
verbose: process.env.DEBUG === 'true' ? true : false
|
verbose: process.env.DEBUG === 'true' ? true : false
|
||||||
}
|
}
|
||||||
|
|
||||||
const chatPrompt = ChatPromptTemplate.fromPromptMessages([
|
let chatPrompt: any
|
||||||
SystemMessagePromptTemplate.fromTemplate(prompt ? `${prompt}\n${systemMessage}` : systemMessage),
|
if (memory instanceof EntityMemory) {
|
||||||
new MessagesPlaceholder(memory.memoryKey ?? 'chat_history'),
|
chatPrompt = ENTITY_MEMORY_CONVERSATION_TEMPLATE
|
||||||
HumanMessagePromptTemplate.fromTemplate('{input}')
|
console.log('use ENTITY_MEMORY_CONVERSATION_TEMPLATE')
|
||||||
])
|
} else {
|
||||||
|
chatPrompt = ChatPromptTemplate.fromPromptMessages([
|
||||||
|
SystemMessagePromptTemplate.fromTemplate(prompt ? `${prompt}\n${systemMessage}` : systemMessage),
|
||||||
|
new MessagesPlaceholder(memory.memoryKey ?? 'chat_history'),
|
||||||
|
HumanMessagePromptTemplate.fromTemplate('{input}')
|
||||||
|
])
|
||||||
|
}
|
||||||
obj.prompt = chatPrompt
|
obj.prompt = chatPrompt
|
||||||
|
|
||||||
const chain = new ConversationChain(obj)
|
const chain = new ConversationChain(obj)
|
||||||
|
|
@ -73,9 +79,10 @@ class ConversationChain_Chains implements INode {
|
||||||
|
|
||||||
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
|
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
|
||||||
const chain = nodeData.instance as ConversationChain
|
const chain = nodeData.instance as ConversationChain
|
||||||
const memory = nodeData.inputs?.memory as BufferMemory
|
const memory = nodeData.inputs?.memory as BufferMemory | EntityMemory
|
||||||
|
|
||||||
if (options && options.chatHistory) {
|
if (options && options.chatHistory) {
|
||||||
|
console.log(`Went into options.chatHistory`)
|
||||||
const chatHistory = []
|
const chatHistory = []
|
||||||
const histories: IMessage[] = options.chatHistory
|
const histories: IMessage[] = options.chatHistory
|
||||||
|
|
||||||
|
|
@ -90,12 +97,22 @@ class ConversationChain_Chains implements INode {
|
||||||
chain.memory = memory
|
chain.memory = memory
|
||||||
}
|
}
|
||||||
|
|
||||||
|
console.log(`chain.memory is instanceof EntityMemory: ${chain.memory instanceof EntityMemory}`)
|
||||||
|
|
||||||
if (options.socketIO && options.socketIOClientId) {
|
if (options.socketIO && options.socketIOClientId) {
|
||||||
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
|
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
|
||||||
const res = await chain.call({ input }, [handler])
|
const res = await chain.call({ input }, [handler])
|
||||||
|
if (memory instanceof EntityMemory) console.log({
|
||||||
|
res,
|
||||||
|
memory: await memory.loadMemoryVariables({ input: "Who is Jim?" }),
|
||||||
|
})
|
||||||
return res?.response
|
return res?.response
|
||||||
} else {
|
} else {
|
||||||
const res = await chain.call({ input })
|
const res = await chain.call({ input })
|
||||||
|
if (memory instanceof EntityMemory) console.log({
|
||||||
|
res,
|
||||||
|
memory: await memory.loadMemoryVariables({ input: "Who is Jim?" }),
|
||||||
|
})
|
||||||
return res?.text
|
return res?.text
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
import { EntityMemoryInput } from 'langchain/dist/memory/entity_memory'
|
||||||
|
import { INode, INodeData, INodeParams } from '../../../src/Interface'
|
||||||
|
import { getBaseClasses } from '../../../src/utils'
|
||||||
|
import { BaseChatMemoryInput, BufferMemory, EntityMemory } from 'langchain/memory'
|
||||||
|
import { BaseLanguageModel } from 'langchain/base_language'
|
||||||
|
|
||||||
|
class EntityMemory_Memory implements INode {
|
||||||
|
label: string
|
||||||
|
name: string
|
||||||
|
description: string
|
||||||
|
type: string
|
||||||
|
icon: string
|
||||||
|
category: string
|
||||||
|
baseClasses: string[]
|
||||||
|
inputs: INodeParams[]
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
this.label = 'Entity Memory'
|
||||||
|
this.name = 'entityMemory'
|
||||||
|
this.type = 'EntityMemory'
|
||||||
|
this.icon = 'memory.svg'
|
||||||
|
this.category = 'Memory'
|
||||||
|
this.description = 'Using an LLM to extracts information on entities and builds up its knowledge about that entity over time'
|
||||||
|
this.baseClasses = [this.type, ...getBaseClasses(EntityMemory)]
|
||||||
|
this.inputs = [
|
||||||
|
{
|
||||||
|
label: 'Language Model',
|
||||||
|
name: 'model',
|
||||||
|
type: 'BaseLanguageModel'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Chat History Key',
|
||||||
|
name: 'chatHistoryKey',
|
||||||
|
type: 'string',
|
||||||
|
default: 'history'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Entities Key',
|
||||||
|
name: 'entitiesKey',
|
||||||
|
type: 'string',
|
||||||
|
default: 'entities'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
async init(nodeData: INodeData): Promise<any> {
|
||||||
|
const model = nodeData.inputs?.model as BaseLanguageModel
|
||||||
|
const chatHistoryKey = nodeData.inputs?.chatHistoryKey as string
|
||||||
|
const entitiesKey = nodeData.inputs?.entitiesKey as string
|
||||||
|
|
||||||
|
const obj: EntityMemoryInput = {
|
||||||
|
llm: model,
|
||||||
|
returnMessages: true,
|
||||||
|
chatHistoryKey,
|
||||||
|
entitiesKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return new EntityMemory(obj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = { nodeClass: EntityMemory_Memory }
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-book" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
|
||||||
|
<path stroke="none" d="M0 0h24v24H0z" fill="none"></path>
|
||||||
|
<path d="M3 19a9 9 0 0 1 9 0a9 9 0 0 1 9 0"></path>
|
||||||
|
<path d="M3 6a9 9 0 0 1 9 0a9 9 0 0 1 9 0"></path>
|
||||||
|
<path d="M3 6l0 13"></path>
|
||||||
|
<path d="M12 6l0 13"></path>
|
||||||
|
<path d="M21 6l0 13"></path>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 495 B |
Loading…
Reference in New Issue