From 0a36aa7ef4d63aa0b8095f3dcb8377014e90fbfa Mon Sep 17 00:00:00 2001 From: Henry Heng Date: Sat, 17 Aug 2024 19:28:01 +0100 Subject: [PATCH] Feature/Add message history to agents (#3031) add message history to agents --- .../nodes/agents/ToolAgent/ToolAgent.ts | 42 +++++++++- .../ChatPromptTemplate/ChatPromptTemplate.ts | 78 ++++++++++++++++++- .../nodes/sequentialagents/Agent/Agent.ts | 55 +++++++++++-- .../nodes/sequentialagents/LLMNode/LLMNode.ts | 49 +++++++++++- .../nodes/sequentialagents/commonUtils.ts | 32 ++++++++ packages/components/src/utils.ts | 16 ++++ 6 files changed, 256 insertions(+), 16 deletions(-) diff --git a/packages/components/nodes/agents/ToolAgent/ToolAgent.ts b/packages/components/nodes/agents/ToolAgent/ToolAgent.ts index 837efb2a5..c56138da7 100644 --- a/packages/components/nodes/agents/ToolAgent/ToolAgent.ts +++ b/packages/components/nodes/agents/ToolAgent/ToolAgent.ts @@ -7,7 +7,7 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models' import { ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate, PromptTemplate } from '@langchain/core/prompts' import { formatToOpenAIToolMessages } from 'langchain/agents/format_scratchpad/openai_tools' import { type ToolsAgentStep } from 'langchain/agents/openai/output_parser' -import { getBaseClasses } from '../../../src/utils' +import { getBaseClasses, handleEscapeCharacters } from '../../../src/utils' import { FlowiseMemory, ICommonObject, INode, INodeData, INodeParams, IUsedTool, IVisionChatModal } from '../../../src/Interface' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { AgentExecutor, ToolCallingAgentOutputParser } from '../../../src/agents' @@ -31,7 +31,7 @@ class ToolAgent_Agents implements INode { constructor(fields?: { sessionId?: string }) { this.label = 'Tool Agent' this.name = 'toolAgent' - this.version = 1.0 + this.version = 2.0 this.type = 'AgentExecutor' this.category = 'Agents' this.icon = 'toolAgent.png' @@ -56,11 +56,19 @@ class ToolAgent_Agents implements INode { description: 'Only compatible with models that are capable of function calling: ChatOpenAI, ChatMistral, ChatAnthropic, ChatGoogleGenerativeAI, ChatVertexAI, GroqChat' }, + { + label: 'Chat Prompt Template', + name: 'chatPromptTemplate', + type: 'ChatPromptTemplate', + description: 'Override existing prompt with Chat Prompt Template. Human Message must includes {input} variable', + optional: true + }, { label: 'System Message', name: 'systemMessage', type: 'string', default: `You are a helpful AI assistant.`, + description: 'If Chat Prompt Template is provided, this will be ignored', rows: 4, optional: true, additionalParams: true @@ -209,13 +217,38 @@ const prepareAgent = async ( const inputKey = memory.inputKey ? memory.inputKey : 'input' const prependMessages = options?.prependMessages - const prompt = ChatPromptTemplate.fromMessages([ + let prompt = ChatPromptTemplate.fromMessages([ ['system', systemMessage], new MessagesPlaceholder(memoryKey), ['human', `{${inputKey}}`], new MessagesPlaceholder('agent_scratchpad') ]) + let promptVariables = {} + const chatPromptTemplate = nodeData.inputs?.chatPromptTemplate as ChatPromptTemplate + if (chatPromptTemplate && chatPromptTemplate.promptMessages.length) { + const humanPrompt = chatPromptTemplate.promptMessages[chatPromptTemplate.promptMessages.length - 1] + const messages = [ + ...chatPromptTemplate.promptMessages.slice(0, -1), + new MessagesPlaceholder(memoryKey), + humanPrompt, + new MessagesPlaceholder('agent_scratchpad') + ] + prompt = ChatPromptTemplate.fromMessages(messages) + if ((chatPromptTemplate as any).promptValues) { + const promptValuesRaw = (chatPromptTemplate as any).promptValues + const promptValues = handleEscapeCharacters(promptValuesRaw, true) + for (const val in promptValues) { + promptVariables = { + ...promptVariables, + [val]: () => { + return promptValues[val] + } + } + } + } + } + if (llmSupportsVision(model)) { const visionChatModel = model as IVisionChatModal const messageContent = await addImagesToMessages(nodeData, options, model.multiModalOption) @@ -258,7 +291,8 @@ const prepareAgent = async ( [memoryKey]: async (_: { input: string; steps: ToolsAgentStep[] }) => { const messages = (await memory.getChatMessages(flowObj?.sessionId, true, prependMessages)) as BaseMessage[] return messages ?? [] - } + }, + ...promptVariables }, prompt, modelWithTools, diff --git a/packages/components/nodes/prompts/ChatPromptTemplate/ChatPromptTemplate.ts b/packages/components/nodes/prompts/ChatPromptTemplate/ChatPromptTemplate.ts index 5d84b451d..2a6a5b7ad 100644 --- a/packages/components/nodes/prompts/ChatPromptTemplate/ChatPromptTemplate.ts +++ b/packages/components/nodes/prompts/ChatPromptTemplate/ChatPromptTemplate.ts @@ -1,6 +1,33 @@ -import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeParams } from '../../../src/Interface' import { getBaseClasses } from '../../../src/utils' import { ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate } from '@langchain/core/prompts' +import { getVM } from '../../sequentialagents/commonUtils' +import { DataSource } from 'typeorm' +const defaultFunc = `const { AIMessage, HumanMessage, ToolMessage } = require('@langchain/core/messages'); + +return [ + new HumanMessage("What is 333382 🦜 1932?"), + new AIMessage({ + content: "", + tool_calls: [ + { + id: "12345", + name: "calulator", + args: { + number1: 333382, + number2: 1932, + operation: "divide", + }, + }, + ], + }), + new ToolMessage({ + tool_call_id: "12345", + content: "The answer is 172.558.", + }), + new AIMessage("The answer is 172.558."), +]` +const TAB_IDENTIFIER = 'selectedMessagesTab' class ChatPromptTemplate_Prompts implements INode { label: string @@ -16,7 +43,7 @@ class ChatPromptTemplate_Prompts implements INode { constructor() { this.label = 'Chat Prompt Template' this.name = 'chatPromptTemplate' - this.version = 1.0 + this.version = 2.0 this.type = 'ChatPromptTemplate' this.icon = 'prompt.svg' this.category = 'Prompts' @@ -33,6 +60,7 @@ class ChatPromptTemplate_Prompts implements INode { { label: 'Human Message', name: 'humanMessagePrompt', + description: 'This prompt will be added at the end of the messages as human message', type: 'string', rows: 4, placeholder: `{text}` @@ -44,20 +72,62 @@ class ChatPromptTemplate_Prompts implements INode { optional: true, acceptVariable: true, list: true + }, + { + label: 'Messages History', + name: 'messageHistory', + description: 'Add messages after System Message. This is useful when you want to provide few shot examples', + type: 'tabs', + tabIdentifier: TAB_IDENTIFIER, + additionalParams: true, + default: 'messageHistoryCode', + tabs: [ + //TODO: add UI for messageHistory + { + label: 'Add Messages (Code)', + name: 'messageHistoryCode', + type: 'code', + hideCodeExecute: true, + codeExample: defaultFunc, + optional: true, + additionalParams: true + } + ] } ] } - async init(nodeData: INodeData): Promise { + async init(nodeData: INodeData, _: string, options: ICommonObject): Promise { const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string const humanMessagePrompt = nodeData.inputs?.humanMessagePrompt as string const promptValuesStr = nodeData.inputs?.promptValues + const tabIdentifier = nodeData.inputs?.[`${TAB_IDENTIFIER}_${nodeData.id}`] as string + const selectedTab = tabIdentifier ? tabIdentifier.split(`_${nodeData.id}`)[0] : 'messageHistoryCode' + const messageHistoryCode = nodeData.inputs?.messageHistoryCode + const messageHistory = nodeData.inputs?.messageHistory - const prompt = ChatPromptTemplate.fromMessages([ + let prompt = ChatPromptTemplate.fromMessages([ SystemMessagePromptTemplate.fromTemplate(systemMessagePrompt), HumanMessagePromptTemplate.fromTemplate(humanMessagePrompt) ]) + if ((messageHistory && messageHistory === 'messageHistoryCode') || (selectedTab === 'messageHistoryCode' && messageHistoryCode)) { + const appDataSource = options.appDataSource as DataSource + const databaseEntities = options.databaseEntities as IDatabaseEntity + const vm = await getVM(appDataSource, databaseEntities, nodeData, {}) + try { + const response = await vm.run(`module.exports = async function() {${messageHistoryCode}}()`, __dirname) + if (!Array.isArray(response)) throw new Error('Returned message history must be an array') + prompt = ChatPromptTemplate.fromMessages([ + SystemMessagePromptTemplate.fromTemplate(systemMessagePrompt), + ...response, + HumanMessagePromptTemplate.fromTemplate(humanMessagePrompt) + ]) + } catch (e) { + throw new Error(e) + } + } + let promptValues: ICommonObject = {} if (promptValuesStr) { try { diff --git a/packages/components/nodes/sequentialagents/Agent/Agent.ts b/packages/components/nodes/sequentialagents/Agent/Agent.ts index 1116f6d38..f4c47ffe1 100644 --- a/packages/components/nodes/sequentialagents/Agent/Agent.ts +++ b/packages/components/nodes/sequentialagents/Agent/Agent.ts @@ -29,7 +29,8 @@ import { transformObjectPropertyToFunction, restructureMessages, MessagesState, - RunnableCallable + RunnableCallable, + checkMessageHistory } from '../commonUtils' import { END, StateGraph } from '@langchain/langgraph' import { StructuredTool } from '@langchain/core/tools' @@ -149,6 +150,31 @@ const defaultFunc = `const result = $flow.output; return { aggregate: [result.content] };` + +const messageHistoryExample = `const { AIMessage, HumanMessage, ToolMessage } = require('@langchain/core/messages'); + +return [ + new HumanMessage("What is 333382 🦜 1932?"), + new AIMessage({ + content: "", + tool_calls: [ + { + id: "12345", + name: "calulator", + args: { + number1: 333382, + number2: 1932, + operation: "divide", + }, + }, + ], + }), + new ToolMessage({ + tool_call_id: "12345", + content: "The answer is 172.558.", + }), + new AIMessage("The answer is 172.558."), +]` const TAB_IDENTIFIER = 'selectedUpdateStateMemoryTab' class Agent_SeqAgents implements INode { @@ -168,7 +194,7 @@ class Agent_SeqAgents implements INode { constructor() { this.label = 'Agent' this.name = 'seqAgent' - this.version = 2.0 + this.version = 3.0 this.type = 'Agent' this.icon = 'seqAgent.png' this.category = 'Sequential Agents' @@ -199,6 +225,17 @@ class Agent_SeqAgents implements INode { optional: true, additionalParams: true }, + { + label: 'Messages History', + name: 'messageHistory', + description: + 'Return a list of messages between System Prompt and Human Prompt. This is useful when you want to provide few shot examples', + type: 'code', + hideCodeExecute: true, + codeExample: messageHistoryExample, + optional: true, + additionalParams: true + }, { label: 'Tools', name: 'tools', @@ -426,6 +463,8 @@ class Agent_SeqAgents implements INode { llm, interrupt, agent: await createAgent( + nodeData, + options, agentName, state, llm, @@ -515,6 +554,8 @@ class Agent_SeqAgents implements INode { } async function createAgent( + nodeData: INodeData, + options: ICommonObject, agentName: string, state: ISeqAgentsState, llm: BaseChatModel, @@ -535,7 +576,8 @@ async function createAgent( if (systemPrompt) promptArrays.unshift(['system', systemPrompt]) if (humanPrompt) promptArrays.push(['human', humanPrompt]) - const prompt = ChatPromptTemplate.fromMessages(promptArrays) + let prompt = ChatPromptTemplate.fromMessages(promptArrays) + prompt = await checkMessageHistory(nodeData, options, prompt, promptArrays, systemPrompt) if (multiModalMessageContent.length) { const msg = HumanMessagePromptTemplate.fromTemplate([...multiModalMessageContent]) @@ -597,7 +639,9 @@ async function createAgent( if (systemPrompt) promptArrays.unshift(['system', systemPrompt]) if (humanPrompt) promptArrays.push(['human', humanPrompt]) - const prompt = ChatPromptTemplate.fromMessages(promptArrays) + let prompt = ChatPromptTemplate.fromMessages(promptArrays) + prompt = await checkMessageHistory(nodeData, options, prompt, promptArrays, systemPrompt) + if (multiModalMessageContent.length) { const msg = HumanMessagePromptTemplate.fromTemplate([...multiModalMessageContent]) prompt.promptMessages.splice(1, 0, msg) @@ -624,7 +668,8 @@ async function createAgent( if (systemPrompt) promptArrays.unshift(['system', systemPrompt]) if (humanPrompt) promptArrays.push(['human', humanPrompt]) - const prompt = ChatPromptTemplate.fromMessages(promptArrays) + let prompt = ChatPromptTemplate.fromMessages(promptArrays) + prompt = await checkMessageHistory(nodeData, options, prompt, promptArrays, systemPrompt) if (multiModalMessageContent.length) { const msg = HumanMessagePromptTemplate.fromTemplate([...multiModalMessageContent]) diff --git a/packages/components/nodes/sequentialagents/LLMNode/LLMNode.ts b/packages/components/nodes/sequentialagents/LLMNode/LLMNode.ts index 17b5830d9..a5a570064 100644 --- a/packages/components/nodes/sequentialagents/LLMNode/LLMNode.ts +++ b/packages/components/nodes/sequentialagents/LLMNode/LLMNode.ts @@ -25,7 +25,8 @@ import { getVM, processImageMessage, transformObjectPropertyToFunction, - restructureMessages + restructureMessages, + checkMessageHistory } from '../commonUtils' import { ChatGoogleGenerativeAI } from '../../chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI' @@ -130,6 +131,31 @@ return { aggregate: [result.content] };` +const messageHistoryExample = `const { AIMessage, HumanMessage, ToolMessage } = require('@langchain/core/messages'); + +return [ + new HumanMessage("What is 333382 🦜 1932?"), + new AIMessage({ + content: "", + tool_calls: [ + { + id: "12345", + name: "calulator", + args: { + number1: 333382, + number2: 1932, + operation: "divide", + }, + }, + ], + }), + new ToolMessage({ + tool_call_id: "12345", + content: "The answer is 172.558.", + }), + new AIMessage("The answer is 172.558."), +]` + class LLMNode_SeqAgents implements INode { label: string name: string @@ -147,7 +173,7 @@ class LLMNode_SeqAgents implements INode { constructor() { this.label = 'LLM Node' this.name = 'seqLLMNode' - this.version = 2.0 + this.version = 3.0 this.type = 'LLMNode' this.icon = 'llmNode.svg' this.category = 'Sequential Agents' @@ -178,6 +204,17 @@ class LLMNode_SeqAgents implements INode { optional: true, additionalParams: true }, + { + label: 'Messages History', + name: 'messageHistory', + description: + 'Return a list of messages between System Prompt and Human Prompt. This is useful when you want to provide few shot examples', + type: 'code', + hideCodeExecute: true, + codeExample: messageHistoryExample, + optional: true, + additionalParams: true + }, { label: 'Start | Agent | Condition | LLM | Tool Node', name: 'sequentialNode', @@ -355,6 +392,8 @@ class LLMNode_SeqAgents implements INode { state, llm, agent: await createAgent( + nodeData, + options, llmNodeName, state, bindModel || llm, @@ -394,6 +433,8 @@ class LLMNode_SeqAgents implements INode { } async function createAgent( + nodeData: INodeData, + options: ICommonObject, llmNodeName: string, state: ISeqAgentsState, llm: BaseChatModel, @@ -438,7 +479,9 @@ async function createAgent( if (systemPrompt) promptArrays.unshift(['system', systemPrompt]) if (humanPrompt) promptArrays.push(['human', humanPrompt]) - const prompt = ChatPromptTemplate.fromMessages(promptArrays) + let prompt = ChatPromptTemplate.fromMessages(promptArrays) + prompt = await checkMessageHistory(nodeData, options, prompt, promptArrays, systemPrompt) + if (multiModalMessageContent.length) { const msg = HumanMessagePromptTemplate.fromTemplate([...multiModalMessageContent]) prompt.promptMessages.splice(1, 0, msg) diff --git a/packages/components/nodes/sequentialagents/commonUtils.ts b/packages/components/nodes/sequentialagents/commonUtils.ts index b68de0bc4..f8cae0b85 100644 --- a/packages/components/nodes/sequentialagents/commonUtils.ts +++ b/packages/components/nodes/sequentialagents/commonUtils.ts @@ -11,6 +11,7 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models' import { addImagesToMessages, llmSupportsVision } from '../../src/multiModalUtils' import { ICommonObject, IDatabaseEntity, INodeData, ISeqAgentsState, IVisionChatModal } from '../../src/Interface' import { availableDependencies, defaultAllowBuiltInDep, getVars, prepareSandboxVars } from '../../src/utils' +import { ChatPromptTemplate, BaseMessagePromptTemplateLike } from '@langchain/core/prompts' export const checkCondition = (input: string | number | undefined, condition: string, value: string | number = ''): boolean => { if (!input && condition === 'Is Empty') return true @@ -344,3 +345,34 @@ export class RunnableCallable extends Runnable { return returnValue } } + +export const checkMessageHistory = async ( + nodeData: INodeData, + options: ICommonObject, + prompt: ChatPromptTemplate, + promptArrays: BaseMessagePromptTemplateLike[], + sysPrompt: string +) => { + const messageHistory = nodeData.inputs?.messageHistory + + if (messageHistory) { + const appDataSource = options.appDataSource as DataSource + const databaseEntities = options.databaseEntities as IDatabaseEntity + const vm = await getVM(appDataSource, databaseEntities, nodeData, {}) + try { + const response = await vm.run(`module.exports = async function() {${messageHistory}}()`, __dirname) + if (!Array.isArray(response)) throw new Error('Returned message history must be an array') + if (sysPrompt) { + // insert at index 1 + promptArrays.splice(1, 0, ...response) + } else { + promptArrays.unshift(...response) + } + prompt = ChatPromptTemplate.fromMessages(promptArrays) + } catch (e) { + throw new Error(e) + } + } + + return prompt +} diff --git a/packages/components/src/utils.ts b/packages/components/src/utils.ts index 57ac493c8..2b5d68a2b 100644 --- a/packages/components/src/utils.ts +++ b/packages/components/src/utils.ts @@ -26,6 +26,22 @@ export const availableDependencies = [ '@google-ai/generativelanguage', '@google/generative-ai', '@huggingface/inference', + '@langchain/anthropic', + '@langchain/aws', + '@langchain/cohere', + '@langchain/community', + '@langchain/core', + '@langchain/google-genai', + '@langchain/google-vertexai', + '@langchain/groq', + '@langchain/langgraph', + '@langchain/mistralai', + '@langchain/mongodb', + '@langchain/ollama', + '@langchain/openai', + '@langchain/pinecone', + '@langchain/qdrant', + '@langchain/weaviate', '@notionhq/client', '@opensearch-project/opensearch', '@pinecone-database/pinecone',