Feature/Add message history to agents (#3031)

add message history to agents
Henry Heng 2024-08-17 19:28:01 +01:00 committed by GitHub
parent 36c6c6425c
commit 0a36aa7ef4
6 changed files with 256 additions and 16 deletions


@@ -7,7 +7,7 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models'
import { ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate, PromptTemplate } from '@langchain/core/prompts'
import { formatToOpenAIToolMessages } from 'langchain/agents/format_scratchpad/openai_tools'
import { type ToolsAgentStep } from 'langchain/agents/openai/output_parser'
import { getBaseClasses } from '../../../src/utils'
import { getBaseClasses, handleEscapeCharacters } from '../../../src/utils'
import { FlowiseMemory, ICommonObject, INode, INodeData, INodeParams, IUsedTool, IVisionChatModal } from '../../../src/Interface'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { AgentExecutor, ToolCallingAgentOutputParser } from '../../../src/agents'
@@ -31,7 +31,7 @@ class ToolAgent_Agents implements INode {
constructor(fields?: { sessionId?: string }) {
this.label = 'Tool Agent'
this.name = 'toolAgent'
this.version = 1.0
this.version = 2.0
this.type = 'AgentExecutor'
this.category = 'Agents'
this.icon = 'toolAgent.png'
@@ -56,11 +56,19 @@ class ToolAgent_Agents implements INode {
description:
'Only compatible with models that are capable of function calling: ChatOpenAI, ChatMistral, ChatAnthropic, ChatGoogleGenerativeAI, ChatVertexAI, GroqChat'
},
{
label: 'Chat Prompt Template',
name: 'chatPromptTemplate',
type: 'ChatPromptTemplate',
description: 'Override the existing prompt with a Chat Prompt Template. The Human Message must include the {input} variable',
optional: true
},
{
label: 'System Message',
name: 'systemMessage',
type: 'string',
default: `You are a helpful AI assistant.`,
description: 'If a Chat Prompt Template is provided, this will be ignored',
rows: 4,
optional: true,
additionalParams: true
@@ -209,13 +217,38 @@ const prepareAgent = async (
const inputKey = memory.inputKey ? memory.inputKey : 'input'
const prependMessages = options?.prependMessages
const prompt = ChatPromptTemplate.fromMessages([
let prompt = ChatPromptTemplate.fromMessages([
['system', systemMessage],
new MessagesPlaceholder(memoryKey),
['human', `{${inputKey}}`],
new MessagesPlaceholder('agent_scratchpad')
])
let promptVariables = {}
const chatPromptTemplate = nodeData.inputs?.chatPromptTemplate as ChatPromptTemplate
if (chatPromptTemplate && chatPromptTemplate.promptMessages.length) {
const humanPrompt = chatPromptTemplate.promptMessages[chatPromptTemplate.promptMessages.length - 1]
const messages = [
...chatPromptTemplate.promptMessages.slice(0, -1),
new MessagesPlaceholder(memoryKey),
humanPrompt,
new MessagesPlaceholder('agent_scratchpad')
]
prompt = ChatPromptTemplate.fromMessages(messages)
if ((chatPromptTemplate as any).promptValues) {
const promptValuesRaw = (chatPromptTemplate as any).promptValues
const promptValues = handleEscapeCharacters(promptValuesRaw, true)
for (const val in promptValues) {
promptVariables = {
...promptVariables,
[val]: () => {
return promptValues[val]
}
}
}
}
}
if (llmSupportsVision(model)) {
const visionChatModel = model as IVisionChatModal
const messageContent = await addImagesToMessages(nodeData, options, model.multiModalOption)
@@ -258,7 +291,8 @@ const prepareAgent = async (
[memoryKey]: async (_: { input: string; steps: ToolsAgentStep[] }) => {
const messages = (await memory.getChatMessages(flowObj?.sessionId, true, prependMessages)) as BaseMessage[]
return messages ?? []
}
},
...promptVariables
},
prompt,
modelWithTools,

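The Tool Agent change above rebuilds the prompt when a Chat Prompt Template is connected: the template's final (human) message stays last, the chat-history placeholder is spliced in just before it, and every prompt value is wrapped in a zero-argument function so the agent's input-mapping step can resolve it lazily on each call, alongside the async memory loader. A minimal sketch of that assembly, assuming a two-message override template (the template text, memoryKey value, and promptValues are illustrative, not from the commit):

    import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'

    // Hypothetical override template; in the node it arrives via
    // nodeData.inputs.chatPromptTemplate from a connected Chat Prompt Template node.
    const override = ChatPromptTemplate.fromMessages([
        ['system', 'You are a {tone} assistant.'],
        ['human', '{input}']
    ])

    const memoryKey = 'chat_history' // illustrative; normally taken from the memory node

    // Everything before the human message, then history, human message, scratchpad.
    const humanPrompt = override.promptMessages[override.promptMessages.length - 1]
    const prompt = ChatPromptTemplate.fromMessages([
        ...override.promptMessages.slice(0, -1),
        new MessagesPlaceholder(memoryKey),
        humanPrompt,
        new MessagesPlaceholder('agent_scratchpad')
    ])

    // Prompt values become thunks so the RunnableSequence input mapping can
    // resolve them per invocation, next to the async memory loader.
    const promptValues: Record<string, string> = { tone: 'formal' } // hypothetical
    let promptVariables: Record<string, () => string> = {}
    for (const key in promptValues) {
        promptVariables = { ...promptVariables, [key]: () => promptValues[key] }
    }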

@@ -1,6 +1,33 @@
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeParams } from '../../../src/Interface'
import { getBaseClasses } from '../../../src/utils'
import { ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate } from '@langchain/core/prompts'
import { getVM } from '../../sequentialagents/commonUtils'
import { DataSource } from 'typeorm'
const defaultFunc = `const { AIMessage, HumanMessage, ToolMessage } = require('@langchain/core/messages');
return [
new HumanMessage("What is 333382 🦜 1932?"),
new AIMessage({
content: "",
tool_calls: [
{
id: "12345",
name: "calulator",
args: {
number1: 333382,
number2: 1932,
operation: "divide",
},
},
],
}),
new ToolMessage({
tool_call_id: "12345",
content: "The answer is 172.558.",
}),
new AIMessage("The answer is 172.558."),
]`
const TAB_IDENTIFIER = 'selectedMessagesTab'
class ChatPromptTemplate_Prompts implements INode {
label: string
@@ -16,7 +43,7 @@ class ChatPromptTemplate_Prompts implements INode {
constructor() {
this.label = 'Chat Prompt Template'
this.name = 'chatPromptTemplate'
this.version = 1.0
this.version = 2.0
this.type = 'ChatPromptTemplate'
this.icon = 'prompt.svg'
this.category = 'Prompts'
@@ -33,6 +60,7 @@ class ChatPromptTemplate_Prompts implements INode {
{
label: 'Human Message',
name: 'humanMessagePrompt',
description: 'This prompt will be added at the end of the messages as a human message',
type: 'string',
rows: 4,
placeholder: `{text}`
@@ -44,20 +72,62 @@ class ChatPromptTemplate_Prompts implements INode {
optional: true,
acceptVariable: true,
list: true
},
{
label: 'Message History',
name: 'messageHistory',
description: 'Add messages after the System Message. This is useful when you want to provide few-shot examples',
type: 'tabs',
tabIdentifier: TAB_IDENTIFIER,
additionalParams: true,
default: 'messageHistoryCode',
tabs: [
// TODO: add UI for messageHistory
{
label: 'Add Messages (Code)',
name: 'messageHistoryCode',
type: 'code',
hideCodeExecute: true,
codeExample: defaultFunc,
optional: true,
additionalParams: true
}
]
}
]
}
async init(nodeData: INodeData): Promise<any> {
async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> {
const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
const humanMessagePrompt = nodeData.inputs?.humanMessagePrompt as string
const promptValuesStr = nodeData.inputs?.promptValues
const tabIdentifier = nodeData.inputs?.[`${TAB_IDENTIFIER}_${nodeData.id}`] as string
const selectedTab = tabIdentifier ? tabIdentifier.split(`_${nodeData.id}`)[0] : 'messageHistoryCode'
const messageHistoryCode = nodeData.inputs?.messageHistoryCode
const messageHistory = nodeData.inputs?.messageHistory
const prompt = ChatPromptTemplate.fromMessages([
let prompt = ChatPromptTemplate.fromMessages([
SystemMessagePromptTemplate.fromTemplate(systemMessagePrompt),
HumanMessagePromptTemplate.fromTemplate(humanMessagePrompt)
])
if ((messageHistory && messageHistory === 'messageHistoryCode') || (selectedTab === 'messageHistoryCode' && messageHistoryCode)) {
const appDataSource = options.appDataSource as DataSource
const databaseEntities = options.databaseEntities as IDatabaseEntity
const vm = await getVM(appDataSource, databaseEntities, nodeData, {})
try {
const response = await vm.run(`module.exports = async function() {${messageHistoryCode}}()`, __dirname)
if (!Array.isArray(response)) throw new Error('Returned message history must be an array')
prompt = ChatPromptTemplate.fromMessages([
SystemMessagePromptTemplate.fromTemplate(systemMessagePrompt),
...response,
HumanMessagePromptTemplate.fromTemplate(humanMessagePrompt)
])
} catch (e) {
throw new Error(e)
}
}
let promptValues: ICommonObject = {}
if (promptValuesStr) {
try {

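The net effect of the new Message History tab on this node: whatever array the sandboxed snippet returns is spread between the system and human templates. A minimal sketch of the resulting composition, with a hypothetical few-shot pair standing in for the sandbox's return value (the system and human templates are illustrative):

    import { ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate } from '@langchain/core/prompts'
    import { AIMessage, HumanMessage } from '@langchain/core/messages'

    // Stand-in for the array the user's messageHistoryCode returns from the VM.
    const fewShot = [new HumanMessage('2 + 2?'), new AIMessage('4')]

    // Mirrors the rebuild in init(): system template, then history, then human template.
    const prompt = ChatPromptTemplate.fromMessages([
        SystemMessagePromptTemplate.fromTemplate('You are a terse math tutor.'),
        ...fewShot,
        HumanMessagePromptTemplate.fromTemplate('{text}')
    ])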

@@ -29,7 +29,8 @@ import {
transformObjectPropertyToFunction,
restructureMessages,
MessagesState,
RunnableCallable
RunnableCallable,
checkMessageHistory
} from '../commonUtils'
import { END, StateGraph } from '@langchain/langgraph'
import { StructuredTool } from '@langchain/core/tools'
@@ -149,6 +150,31 @@ const defaultFunc = `const result = $flow.output;
return {
aggregate: [result.content]
};`
const messageHistoryExample = `const { AIMessage, HumanMessage, ToolMessage } = require('@langchain/core/messages');
return [
new HumanMessage("What is 333382 🦜 1932?"),
new AIMessage({
content: "",
tool_calls: [
{
id: "12345",
name: "calulator",
args: {
number1: 333382,
number2: 1932,
operation: "divide",
},
},
],
}),
new ToolMessage({
tool_call_id: "12345",
content: "The answer is 172.558.",
}),
new AIMessage("The answer is 172.558."),
]`
const TAB_IDENTIFIER = 'selectedUpdateStateMemoryTab'
class Agent_SeqAgents implements INode {
@@ -168,7 +194,7 @@ class Agent_SeqAgents implements INode {
constructor() {
this.label = 'Agent'
this.name = 'seqAgent'
this.version = 2.0
this.version = 3.0
this.type = 'Agent'
this.icon = 'seqAgent.png'
this.category = 'Sequential Agents'
@@ -199,6 +225,17 @@ class Agent_SeqAgents implements INode {
optional: true,
additionalParams: true
},
{
label: 'Message History',
name: 'messageHistory',
description:
'Return a list of messages to insert between the System Prompt and the Human Prompt. This is useful when you want to provide few-shot examples',
type: 'code',
hideCodeExecute: true,
codeExample: messageHistoryExample,
optional: true,
additionalParams: true
},
{
label: 'Tools',
name: 'tools',
@@ -426,6 +463,8 @@ class Agent_SeqAgents implements INode {
llm,
interrupt,
agent: await createAgent(
nodeData,
options,
agentName,
state,
llm,
@@ -515,6 +554,8 @@ class Agent_SeqAgents implements INode {
}
async function createAgent(
nodeData: INodeData,
options: ICommonObject,
agentName: string,
state: ISeqAgentsState,
llm: BaseChatModel,
@@ -535,7 +576,8 @@ async function createAgent(
if (systemPrompt) promptArrays.unshift(['system', systemPrompt])
if (humanPrompt) promptArrays.push(['human', humanPrompt])
const prompt = ChatPromptTemplate.fromMessages(promptArrays)
let prompt = ChatPromptTemplate.fromMessages(promptArrays)
prompt = await checkMessageHistory(nodeData, options, prompt, promptArrays, systemPrompt)
if (multiModalMessageContent.length) {
const msg = HumanMessagePromptTemplate.fromTemplate([...multiModalMessageContent])
@@ -597,7 +639,9 @@ async function createAgent(
if (systemPrompt) promptArrays.unshift(['system', systemPrompt])
if (humanPrompt) promptArrays.push(['human', humanPrompt])
const prompt = ChatPromptTemplate.fromMessages(promptArrays)
let prompt = ChatPromptTemplate.fromMessages(promptArrays)
prompt = await checkMessageHistory(nodeData, options, prompt, promptArrays, systemPrompt)
if (multiModalMessageContent.length) {
const msg = HumanMessagePromptTemplate.fromTemplate([...multiModalMessageContent])
prompt.promptMessages.splice(1, 0, msg)
@@ -624,7 +668,8 @@ async function createAgent(
if (systemPrompt) promptArrays.unshift(['system', systemPrompt])
if (humanPrompt) promptArrays.push(['human', humanPrompt])
const prompt = ChatPromptTemplate.fromMessages(promptArrays)
let prompt = ChatPromptTemplate.fromMessages(promptArrays)
prompt = await checkMessageHistory(nodeData, options, prompt, promptArrays, systemPrompt)
if (multiModalMessageContent.length) {
const msg = HumanMessagePromptTemplate.fromTemplate([...multiModalMessageContent])

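Both sequential-agent nodes (Agent above, LLM Node below) hand the user's snippet to checkMessageHistory, which evaluates it in the sandbox that getVM builds. A minimal sketch of that evaluation, assuming a vm2-style NodeVM; the option shape and the snippet are illustrative, and getVM's real configuration (Flowise variables, dependency whitelist) lives in commonUtils:

    import { NodeVM } from 'vm2'

    // Hypothetical user input from the Message History field.
    const snippet = `const { HumanMessage, AIMessage } = require('@langchain/core/messages');
    return [new HumanMessage('ping'), new AIMessage('pong')]`

    // Assumed sandbox setup; only whitelisted modules are requirable inside.
    const vm = new NodeVM({
        require: { external: { modules: ['@langchain/core'], transitive: true } }
    })

    async function loadHistory() {
        // The snippet body is wrapped in an async IIFE, exactly as in the diff.
        const history = await vm.run(`module.exports = async function() {${snippet}}()`, __dirname)
        if (!Array.isArray(history)) throw new Error('Returned message history must be an array')
        return history
    }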

@@ -25,7 +25,8 @@ import {
getVM,
processImageMessage,
transformObjectPropertyToFunction,
restructureMessages
restructureMessages,
checkMessageHistory
} from '../commonUtils'
import { ChatGoogleGenerativeAI } from '../../chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI'
@@ -130,6 +131,31 @@ return {
aggregate: [result.content]
};`
const messageHistoryExample = `const { AIMessage, HumanMessage, ToolMessage } = require('@langchain/core/messages');
return [
new HumanMessage("What is 333382 🦜 1932?"),
new AIMessage({
content: "",
tool_calls: [
{
id: "12345",
name: "calulator",
args: {
number1: 333382,
number2: 1932,
operation: "divide",
},
},
],
}),
new ToolMessage({
tool_call_id: "12345",
content: "The answer is 172.558.",
}),
new AIMessage("The answer is 172.558."),
]`
class LLMNode_SeqAgents implements INode {
label: string
name: string
@@ -147,7 +173,7 @@ class LLMNode_SeqAgents implements INode {
constructor() {
this.label = 'LLM Node'
this.name = 'seqLLMNode'
this.version = 2.0
this.version = 3.0
this.type = 'LLMNode'
this.icon = 'llmNode.svg'
this.category = 'Sequential Agents'
@@ -178,6 +204,17 @@ class LLMNode_SeqAgents implements INode {
optional: true,
additionalParams: true
},
{
label: 'Message History',
name: 'messageHistory',
description:
'Return a list of messages to insert between the System Prompt and the Human Prompt. This is useful when you want to provide few-shot examples',
type: 'code',
hideCodeExecute: true,
codeExample: messageHistoryExample,
optional: true,
additionalParams: true
},
{
label: 'Start | Agent | Condition | LLM | Tool Node',
name: 'sequentialNode',
@@ -355,6 +392,8 @@ class LLMNode_SeqAgents implements INode {
state,
llm,
agent: await createAgent(
nodeData,
options,
llmNodeName,
state,
bindModel || llm,
@@ -394,6 +433,8 @@ class LLMNode_SeqAgents implements INode {
}
async function createAgent(
nodeData: INodeData,
options: ICommonObject,
llmNodeName: string,
state: ISeqAgentsState,
llm: BaseChatModel,
@@ -438,7 +479,9 @@ async function createAgent(
if (systemPrompt) promptArrays.unshift(['system', systemPrompt])
if (humanPrompt) promptArrays.push(['human', humanPrompt])
const prompt = ChatPromptTemplate.fromMessages(promptArrays)
let prompt = ChatPromptTemplate.fromMessages(promptArrays)
prompt = await checkMessageHistory(nodeData, options, prompt, promptArrays, systemPrompt)
if (multiModalMessageContent.length) {
const msg = HumanMessagePromptTemplate.fromTemplate([...multiModalMessageContent])
prompt.promptMessages.splice(1, 0, msg)

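One interaction worth noting in both sequential nodes: after checkMessageHistory rebuilds the prompt, the vision branch still splices its image-bearing human message in at index 1, so it lands directly after the system message and ahead of any injected history. A minimal sketch of that ordering (the message and template contents are illustrative):

    import { ChatPromptTemplate, HumanMessagePromptTemplate } from '@langchain/core/prompts'
    import { AIMessage, HumanMessage } from '@langchain/core/messages'

    const prompt = ChatPromptTemplate.fromMessages([
        ['system', 'Describe what you see.'],
        new HumanMessage('ping'), // stand-in for injected message history
        new AIMessage('pong'),
        ['human', '{input}']
    ])

    // Roughly what addImagesToMessages contributes; the exact content shape
    // depends on the chat model's multimodal support.
    const msg = HumanMessagePromptTemplate.fromTemplate([
        { type: 'image_url', image_url: '{imageUrl}' } // hypothetical image part
    ])

    // Index 1 = right after the system message, before the injected history.
    prompt.promptMessages.splice(1, 0, msg)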

@@ -11,6 +11,7 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models'
import { addImagesToMessages, llmSupportsVision } from '../../src/multiModalUtils'
import { ICommonObject, IDatabaseEntity, INodeData, ISeqAgentsState, IVisionChatModal } from '../../src/Interface'
import { availableDependencies, defaultAllowBuiltInDep, getVars, prepareSandboxVars } from '../../src/utils'
import { ChatPromptTemplate, BaseMessagePromptTemplateLike } from '@langchain/core/prompts'
export const checkCondition = (input: string | number | undefined, condition: string, value: string | number = ''): boolean => {
if (!input && condition === 'Is Empty') return true
@@ -344,3 +345,34 @@ export class RunnableCallable<I = unknown, O = unknown> extends Runnable<I, O> {
return returnValue
}
}
export const checkMessageHistory = async (
nodeData: INodeData,
options: ICommonObject,
prompt: ChatPromptTemplate,
promptArrays: BaseMessagePromptTemplateLike[],
sysPrompt: string
) => {
const messageHistory = nodeData.inputs?.messageHistory
if (messageHistory) {
const appDataSource = options.appDataSource as DataSource
const databaseEntities = options.databaseEntities as IDatabaseEntity
const vm = await getVM(appDataSource, databaseEntities, nodeData, {})
try {
const response = await vm.run(`module.exports = async function() {${messageHistory}}()`, __dirname)
if (!Array.isArray(response)) throw new Error('Returned message history must be an array')
if (sysPrompt) {
// insert right after the system prompt, at index 1
promptArrays.splice(1, 0, ...response)
} else {
promptArrays.unshift(...response)
}
prompt = ChatPromptTemplate.fromMessages(promptArrays)
} catch (e) {
throw new Error(e)
}
}
return prompt
}

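The placement rule in checkMessageHistory is the part callers rely on: if a system prompt exists, the returned messages are spliced in at index 1, right behind it; otherwise they are prepended. A minimal standalone sketch of just that rule (the messages and templates are illustrative):

    import { ChatPromptTemplate, BaseMessagePromptTemplateLike } from '@langchain/core/prompts'
    import { AIMessage, HumanMessage } from '@langchain/core/messages'

    const history = [new HumanMessage('ping'), new AIMessage('pong')] // hypothetical sandbox output

    const promptArrays: BaseMessagePromptTemplateLike[] = [
        ['system', 'You are terse.'],
        ['human', '{input}']
    ]
    const sysPrompt = 'You are terse.'

    if (sysPrompt) {
        promptArrays.splice(1, 0, ...history) // right behind the system prompt
    } else {
        promptArrays.unshift(...history) // no system prompt: history leads
    }
    const prompt = ChatPromptTemplate.fromMessages(promptArrays)

The final file below extends the sandbox's availableDependencies whitelist with the @langchain/* packages, which is presumably what lets snippets like the ones above require('@langchain/core/messages') inside the VM.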

@@ -26,6 +26,22 @@ export const availableDependencies = [
'@google-ai/generativelanguage',
'@google/generative-ai',
'@huggingface/inference',
'@langchain/anthropic',
'@langchain/aws',
'@langchain/cohere',
'@langchain/community',
'@langchain/core',
'@langchain/google-genai',
'@langchain/google-vertexai',
'@langchain/groq',
'@langchain/langgraph',
'@langchain/mistralai',
'@langchain/mongodb',
'@langchain/ollama',
'@langchain/openai',
'@langchain/pinecone',
'@langchain/qdrant',
'@langchain/weaviate',
'@notionhq/client',
'@opensearch-project/opensearch',
'@pinecone-database/pinecone',