add memory option
This commit is contained in:
parent
21552776d6
commit
fb28f61f8b
|
|
@ -2,7 +2,7 @@ import { BaseLanguageModel } from 'langchain/base_language'
|
||||||
import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
|
import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
|
||||||
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
|
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
|
||||||
import { ConversationalRetrievalQAChain } from 'langchain/chains'
|
import { ConversationalRetrievalQAChain } from 'langchain/chains'
|
||||||
import { AIChatMessage, BaseRetriever, HumanChatMessage } from 'langchain/schema'
|
import { AIMessage, BaseRetriever, HumanMessage } from 'langchain/schema'
|
||||||
import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory'
|
import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory'
|
||||||
import { PromptTemplate } from 'langchain/prompts'
|
import { PromptTemplate } from 'langchain/prompts'
|
||||||
|
|
||||||
|
|
@ -20,6 +20,20 @@ const qa_template = `Use the following pieces of context to answer the question
|
||||||
Question: {question}
|
Question: {question}
|
||||||
Helpful Answer:`
|
Helpful Answer:`
|
||||||
|
|
||||||
|
const CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT = `Given the following conversation and a follow up question, return the conversation history excerpt that includes any relevant context to the question if it exists and rephrase the follow up question to be a standalone question.
|
||||||
|
Chat History:
|
||||||
|
{chat_history}
|
||||||
|
Follow Up Input: {question}
|
||||||
|
Your answer should follow the following format:
|
||||||
|
\`\`\`
|
||||||
|
Use the following pieces of context to answer the users question.
|
||||||
|
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||||
|
----------------
|
||||||
|
<Relevant chat history excerpt as context here>
|
||||||
|
Standalone question: <Rephrased question here>
|
||||||
|
\`\`\`
|
||||||
|
Your answer:`
|
||||||
|
|
||||||
class ConversationalRetrievalQAChain_Chains implements INode {
|
class ConversationalRetrievalQAChain_Chains implements INode {
|
||||||
label: string
|
label: string
|
||||||
name: string
|
name: string
|
||||||
|
|
@ -49,6 +63,13 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||||
name: 'vectorStoreRetriever',
|
name: 'vectorStoreRetriever',
|
||||||
type: 'BaseRetriever'
|
type: 'BaseRetriever'
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
label: 'Memory',
|
||||||
|
name: 'memory',
|
||||||
|
type: 'DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory',
|
||||||
|
optional: true,
|
||||||
|
description: 'If no memory connected, default BufferMemory will be used'
|
||||||
|
},
|
||||||
{
|
{
|
||||||
label: 'Return Source Documents',
|
label: 'Return Source Documents',
|
||||||
name: 'returnSourceDocuments',
|
name: 'returnSourceDocuments',
|
||||||
|
|
@ -99,6 +120,7 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||||
const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
|
const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
|
||||||
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
|
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
|
||||||
const chainOption = nodeData.inputs?.chainOption as string
|
const chainOption = nodeData.inputs?.chainOption as string
|
||||||
|
const memory = nodeData.inputs?.memory
|
||||||
|
|
||||||
const obj: any = {
|
const obj: any = {
|
||||||
verbose: process.env.DEBUG === 'true' ? true : false,
|
verbose: process.env.DEBUG === 'true' ? true : false,
|
||||||
|
|
@ -106,15 +128,25 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||||
type: 'stuff',
|
type: 'stuff',
|
||||||
prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
|
prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
|
||||||
},
|
},
|
||||||
memory: new BufferMemory({
|
questionGeneratorChainOptions: {
|
||||||
|
template: CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
|
||||||
|
if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }
|
||||||
|
if (memory) {
|
||||||
|
memory.inputKey = 'question'
|
||||||
|
memory.outputKey = 'text'
|
||||||
|
memory.memoryKey = 'chat_history'
|
||||||
|
obj.memory = memory
|
||||||
|
} else {
|
||||||
|
obj.memory = new BufferMemory({
|
||||||
memoryKey: 'chat_history',
|
memoryKey: 'chat_history',
|
||||||
inputKey: 'question',
|
inputKey: 'question',
|
||||||
outputKey: 'text',
|
outputKey: 'text',
|
||||||
returnMessages: true
|
returnMessages: true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
|
|
||||||
if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }
|
|
||||||
|
|
||||||
const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
|
const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
|
||||||
return chain
|
return chain
|
||||||
|
|
@ -123,6 +155,8 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||||
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
|
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
|
||||||
const chain = nodeData.instance as ConversationalRetrievalQAChain
|
const chain = nodeData.instance as ConversationalRetrievalQAChain
|
||||||
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
|
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
|
||||||
|
const memory = nodeData.inputs?.memory
|
||||||
|
|
||||||
let model = nodeData.inputs?.model
|
let model = nodeData.inputs?.model
|
||||||
|
|
||||||
// Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
|
// Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
|
||||||
|
|
@ -131,16 +165,17 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||||
|
|
||||||
const obj = { question: input }
|
const obj = { question: input }
|
||||||
|
|
||||||
if (chain.memory && options && options.chatHistory) {
|
// If external memory like Zep, Redis is being used, ignore below
|
||||||
|
if (!memory && chain.memory && options && options.chatHistory) {
|
||||||
const chatHistory = []
|
const chatHistory = []
|
||||||
const histories: IMessage[] = options.chatHistory
|
const histories: IMessage[] = options.chatHistory
|
||||||
const memory = chain.memory as BaseChatMemory
|
const memory = chain.memory as BaseChatMemory
|
||||||
|
|
||||||
for (const message of histories) {
|
for (const message of histories) {
|
||||||
if (message.type === 'apiMessage') {
|
if (message.type === 'apiMessage') {
|
||||||
chatHistory.push(new AIChatMessage(message.message))
|
chatHistory.push(new AIMessage(message.message))
|
||||||
} else if (message.type === 'userMessage') {
|
} else if (message.type === 'userMessage') {
|
||||||
chatHistory.push(new HumanChatMessage(message.message))
|
chatHistory.push(new HumanMessage(message.message))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
memory.chatHistory = new ChatMessageHistory(chatHistory)
|
memory.chatHistory = new ChatMessageHistory(chatHistory)
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,9 @@ class DynamoDb_Memory implements INode {
|
||||||
inputs: INodeParams[]
|
inputs: INodeParams[]
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
this.label = 'DynamoDB Memory'
|
this.label = 'DynamoDB Chat Memory'
|
||||||
this.name = 'DynamoDbMemory'
|
this.name = 'DynamoDBChatMemory'
|
||||||
|
this.type = 'DynamoDBChatMemory'
|
||||||
this.icon = 'dynamodb.svg'
|
this.icon = 'dynamodb.svg'
|
||||||
this.category = 'Memory'
|
this.category = 'Memory'
|
||||||
this.description = 'Stores the conversation in dynamo db table'
|
this.description = 'Stores the conversation in dynamo db table'
|
||||||
|
|
|
||||||
|
|
@ -168,8 +168,10 @@ export const isValidConnection = (connection, reactFlowInstance) => {
|
||||||
//sourceHandle: "llmChain_0-output-llmChain-BaseChain"
|
//sourceHandle: "llmChain_0-output-llmChain-BaseChain"
|
||||||
//targetHandle: "mrlkAgentLLM_0-input-model-BaseLanguageModel"
|
//targetHandle: "mrlkAgentLLM_0-input-model-BaseLanguageModel"
|
||||||
|
|
||||||
const sourceTypes = sourceHandle.split('-')[sourceHandle.split('-').length - 1].split('|')
|
let sourceTypes = sourceHandle.split('-')[sourceHandle.split('-').length - 1].split('|')
|
||||||
const targetTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|')
|
sourceTypes = sourceTypes.map((s) => s.trim())
|
||||||
|
let targetTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|')
|
||||||
|
targetTypes = targetTypes.map((t) => t.trim())
|
||||||
|
|
||||||
if (targetTypes.some((t) => sourceTypes.includes(t))) {
|
if (targetTypes.some((t) => sourceTypes.includes(t))) {
|
||||||
let targetNode = reactFlowInstance.getNode(target)
|
let targetNode = reactFlowInstance.getNode(target)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue