add memory option
parent 21552776d6
commit fb28f61f8b
@@ -2,7 +2,7 @@ import { BaseLanguageModel } from 'langchain/base_language'
 import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
 import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
 import { ConversationalRetrievalQAChain } from 'langchain/chains'
-import { AIChatMessage, BaseRetriever, HumanChatMessage } from 'langchain/schema'
+import { AIMessage, BaseRetriever, HumanMessage } from 'langchain/schema'
 import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory'
 import { PromptTemplate } from 'langchain/prompts'
 
@@ -20,6 +20,20 @@ const qa_template = `Use the following pieces of context to answer the question
 Question: {question}
 Helpful Answer:`
 
+const CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT = `Given the following conversation and a follow up question, return the conversation history excerpt that includes any relevant context to the question if it exists and rephrase the follow up question to be a standalone question.
+Chat History:
+{chat_history}
+Follow Up Input: {question}
+Your answer should follow the following format:
+\`\`\`
+Use the following pieces of context to answer the users question.
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+----------------
+<Relevant chat history excerpt as context here>
+Standalone question: <Rephrased question here>
+\`\`\`
+Your answer:`
+
 class ConversationalRetrievalQAChain_Chains implements INode {
     label: string
     name: string
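Note: a quick standalone check of how a template like this behaves once wrapped in a PromptTemplate. This is a minimal sketch using a shortened template and made-up chat history, not code from the commit; fromTemplate() infers {chat_history} and {question} as the input variables.

```typescript
import { PromptTemplate } from 'langchain/prompts'

// Shortened stand-in for CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT.
const questionGeneratorPrompt = PromptTemplate.fromTemplate(
    `Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:`
)

async function preview(): Promise<void> {
    // format() substitutes the two inferred input variables.
    const text = await questionGeneratorPrompt.format({
        chat_history: 'Human: What is Flowise?\nAI: A drag-and-drop UI for LLM flows.',
        question: 'Does it support memory?'
    })
    console.log(text)
}

preview()
```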
@@ -49,6 +63,13 @@ class ConversationalRetrievalQAChain_Chains implements INode {
                 name: 'vectorStoreRetriever',
                 type: 'BaseRetriever'
             },
+            {
+                label: 'Memory',
+                name: 'memory',
+                type: 'DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory',
+                optional: true,
+                description: 'If no memory connected, default BufferMemory will be used'
+            },
             {
                 label: 'Return Source Documents',
                 name: 'returnSourceDocuments',
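Note: the pipe-separated type string is how one input slot can accept several node types; the UI splits it on '|' when validating an edge (see the isValidConnection change at the end of this commit). A rough sketch of the parameter shape, with INodeParams approximated since the full interface is not shown here:

```typescript
// Approximation of Flowise's INodeParams, for illustration only.
interface INodeParamsSketch {
    label: string
    name: string
    type: string // a single type, or a pipe-separated union like 'A | B | C'
    optional?: boolean
    description?: string
}

const memoryInput: INodeParamsSketch = {
    label: 'Memory',
    name: 'memory',
    type: 'DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory',
    optional: true,
    description: 'If no memory connected, default BufferMemory will be used'
}
```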
@@ -99,6 +120,7 @@ class ConversationalRetrievalQAChain_Chains implements INode {
         const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
         const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
         const chainOption = nodeData.inputs?.chainOption as string
+        const memory = nodeData.inputs?.memory
 
         const obj: any = {
             verbose: process.env.DEBUG === 'true' ? true : false,
@@ -106,15 +128,25 @@ class ConversationalRetrievalQAChain_Chains implements INode {
                 type: 'stuff',
                 prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
             },
-            memory: new BufferMemory({
+            questionGeneratorChainOptions: {
+                template: CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT
+            }
+        }
+        if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
+        if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }
+        if (memory) {
+            memory.inputKey = 'question'
+            memory.outputKey = 'text'
+            memory.memoryKey = 'chat_history'
+            obj.memory = memory
+        } else {
+            obj.memory = new BufferMemory({
                 memoryKey: 'chat_history',
                 inputKey: 'question',
                 outputKey: 'text',
                 returnMessages: true
             })
-        }
-        if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
-        if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }
+        }
 
         const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
         return chain
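Note: the branch above is the core of the memory option: an externally connected memory (DynamoDB, Redis, Zep) is re-keyed to the keys the chain reads and writes, and a BufferMemory is created only as a fallback. A minimal sketch of the same pattern; the parameter is typed loosely because nodeData.inputs?.memory is untyped above, and memoryKey lives on concrete memories such as BufferMemory rather than the base class:

```typescript
import { BufferMemory, BaseChatMemory } from 'langchain/memory'

// Sketch only: `memory` mirrors the untyped nodeData.inputs?.memory above.
function resolveMemory(memory?: any): BaseChatMemory {
    if (memory) {
        // External memories must read/write with the chain's expected keys.
        memory.inputKey = 'question'
        memory.outputKey = 'text'
        memory.memoryKey = 'chat_history'
        return memory
    }
    // Default when no memory node is connected.
    return new BufferMemory({
        memoryKey: 'chat_history',
        inputKey: 'question',
        outputKey: 'text',
        returnMessages: true
    })
}
```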
@@ -123,6 +155,8 @@ class ConversationalRetrievalQAChain_Chains implements INode {
     async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
         const chain = nodeData.instance as ConversationalRetrievalQAChain
         const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
+        const memory = nodeData.inputs?.memory
+
         let model = nodeData.inputs?.model
 
         // Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
@@ -131,16 +165,17 @@ class ConversationalRetrievalQAChain_Chains implements INode {
 
         const obj = { question: input }
 
-        if (chain.memory && options && options.chatHistory) {
+        // If external memory like Zep, Redis is being used, ignore below
+        if (!memory && chain.memory && options && options.chatHistory) {
             const chatHistory = []
             const histories: IMessage[] = options.chatHistory
             const memory = chain.memory as BaseChatMemory
 
             for (const message of histories) {
                 if (message.type === 'apiMessage') {
-                    chatHistory.push(new AIChatMessage(message.message))
+                    chatHistory.push(new AIMessage(message.message))
                 } else if (message.type === 'userMessage') {
-                    chatHistory.push(new HumanChatMessage(message.message))
+                    chatHistory.push(new HumanMessage(message.message))
                 }
             }
             memory.chatHistory = new ChatMessageHistory(chatHistory)
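Note: the loop rebuilds the chain's in-memory history from the UI's stored messages, using the renamed langchain message classes (AIChatMessage/HumanChatMessage became AIMessage/HumanMessage upstream). A self-contained sketch of the conversion; IMessage is approximated here since the Flowise interface is not shown in this diff:

```typescript
import { AIMessage, HumanMessage, BaseMessage } from 'langchain/schema'
import { ChatMessageHistory } from 'langchain/memory'

// Approximation of Flowise's IMessage, for illustration only.
interface IMessageSketch {
    message: string
    type: 'apiMessage' | 'userMessage'
}

// Convert stored UI messages into a langchain ChatMessageHistory.
function toChatMessageHistory(histories: IMessageSketch[]): ChatMessageHistory {
    const chatHistory: BaseMessage[] = []
    for (const { type, message } of histories) {
        if (type === 'apiMessage') chatHistory.push(new AIMessage(message))
        else if (type === 'userMessage') chatHistory.push(new HumanMessage(message))
    }
    return new ChatMessageHistory(chatHistory)
}
```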
@@ -13,8 +13,9 @@ class DynamoDb_Memory implements INode {
     inputs: INodeParams[]
 
     constructor() {
-        this.label = 'DynamoDB Memory'
-        this.name = 'DynamoDbMemory'
+        this.label = 'DynamoDB Chat Memory'
+        this.name = 'DynamoDBChatMemory'
+        this.type = 'DynamoDBChatMemory'
         this.icon = 'dynamodb.svg'
         this.category = 'Memory'
         this.description = 'Stores the conversation in dynamo db table'
@@ -168,8 +168,10 @@ export const isValidConnection = (connection, reactFlowInstance) => {
     //sourceHandle: "llmChain_0-output-llmChain-BaseChain"
     //targetHandle: "mrlkAgentLLM_0-input-model-BaseLanguageModel"
 
-    const sourceTypes = sourceHandle.split('-')[sourceHandle.split('-').length - 1].split('|')
-    const targetTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|')
+    let sourceTypes = sourceHandle.split('-')[sourceHandle.split('-').length - 1].split('|')
+    sourceTypes = sourceTypes.map((s) => s.trim())
+    let targetTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|')
+    targetTypes = targetTypes.map((t) => t.trim())
 
     if (targetTypes.some((t) => sourceTypes.includes(t))) {
         let targetNode = reactFlowInstance.getNode(target)
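Note: the trim is needed because the new Memory input declares its type with spaces around the pipes, and those spaces survive the split, so type matching would silently fail. A sketch with a hypothetical handle id:

```typescript
// Hypothetical target handle for the new Memory input.
const targetHandle =
    'conversationalRetrievalQAChain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory'

const parts = targetHandle.split('-')
let targetTypes = parts[parts.length - 1].split('|')
// Before trimming, the entries keep stray spaces:
// ['DynamoDBChatMemory ', ' RedisBackedChatMemory ', ' ZepMemory']
console.log(targetTypes.includes('ZepMemory')) // false
targetTypes = targetTypes.map((t) => t.trim())
console.log(targetTypes.includes('ZepMemory')) // true
```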