diff --git a/packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts b/packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts
index c019ca5a9..63749fd2e 100644
--- a/packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts
+++ b/packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts
@@ -1,10 +1,18 @@
 import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
-import { initializeAgentExecutorWithOptions, AgentExecutor } from 'langchain/agents'
-import { getBaseClasses, mapChatHistory } from '../../../src/utils'
-import { BaseLanguageModel } from 'langchain/base_language'
+import { AgentExecutor, AgentExecutorInput } from 'langchain/agents'
+import { ChainValues, AgentStep, AgentFinish, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema'
+import { OutputParserException } from 'langchain/schema/output_parser'
+import { CallbackManagerForChainRun } from 'langchain/callbacks'
+import { formatToOpenAIFunction } from 'langchain/tools'
+import { ToolInputParsingException, Tool } from '@langchain/core/tools'
+import { getBaseClasses } from '../../../src/utils'
 import { flatten } from 'lodash'
 import { BaseChatMemory } from 'langchain/memory'
+import { RunnableSequence } from 'langchain/schema/runnable'
 import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
+import { ChatPromptTemplate, MessagesPlaceholder } from 'langchain/prompts'
+import { ChatOpenAI } from 'langchain/chat_models/openai'
+import { OpenAIFunctionsAgentOutputParser } from 'langchain/agents/openai/output_parser'

 class OpenAIFunctionAgent_Agents implements INode {
     label: string
@@ -16,8 +24,9 @@ class OpenAIFunctionAgent_Agents implements INode {
     category: string
     baseClasses: string[]
     inputs: INodeParams[]
+    sessionId?: string

-    constructor() {
+    constructor(fields: { sessionId?: string }) {
         this.label = 'OpenAI Function Agent'
         this.name = 'openAIFunctionAgent'
         this.version = 3.0
@@ -52,54 +61,323 @@ class OpenAIFunctionAgent_Agents implements INode {
                 additionalParams: true
             }
         ]
+        this.sessionId = fields?.sessionId
     }

     async init(nodeData: INodeData): Promise<any> {
-        const model = nodeData.inputs?.model as BaseLanguageModel
         const memory = nodeData.inputs?.memory as BaseChatMemory
-        const systemMessage = nodeData.inputs?.systemMessage as string
-        let tools = nodeData.inputs?.tools
-        tools = flatten(tools)
-
-        const executor = await initializeAgentExecutorWithOptions(tools, model, {
-            agentType: 'openai-functions',
-            verbose: process.env.DEBUG === 'true' ? true : false,
-            agentArgs: {
-                prefix: systemMessage ?? `You are a helpful AI assistant.`
-            }
-        })
+        const executor = prepareAgent(nodeData, this.sessionId)

         if (memory) executor.memory = memory
         return executor
     }

     async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
-        const executor = nodeData.instance as AgentExecutor
-        const memory = nodeData.inputs?.memory as BaseChatMemory
+        const memory = nodeData.inputs?.memory

-        if (options && options.chatHistory) {
-            const chatHistoryClassName = memory.chatHistory.constructor.name
-            // Only replace when its In-Memory
-            if (chatHistoryClassName && chatHistoryClassName === 'ChatMessageHistory') {
-                memory.chatHistory = mapChatHistory(options)
-                executor.memory = memory
-            }
-        }
-
-        ;(executor.memory as any).returnMessages = true // Return true for BaseChatModel
+        const executor = prepareAgent(nodeData, this.sessionId)

         const loggerHandler = new ConsoleCallbackHandler(options.logger)
         const callbacks = await additionalCallbacks(nodeData, options)

+        let res: ChainValues = {}
+
         if (options.socketIO && options.socketIOClientId) {
             const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
-            const result = await executor.run(input, [loggerHandler, handler, ...callbacks])
-            return result
+            res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] })
         } else {
-            const result = await executor.run(input, [loggerHandler, ...callbacks])
-            return result
+            res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
         }
+
+        await memory.addChatMessages(
+            [
+                {
+                    text: input,
+                    type: 'userMessage'
+                },
+                {
+                    text: res?.output,
+                    type: 'apiMessage'
+                }
+            ],
+            this.sessionId
+        )
+
+        return res?.output
+    }
+}
+
+const formatAgentSteps = (steps: AgentStep[]): BaseMessage[] =>
+    steps.flatMap(({ action, observation }) => {
+        if ('messageLog' in action && action.messageLog !== undefined) {
+            const log = action.messageLog as BaseMessage[]
+            return log.concat(new FunctionMessage(observation, action.tool))
+        } else {
+            return [new AIMessage(action.log)]
+        }
+    })
+
+const prepareAgent = (nodeData: INodeData, sessionId?: string) => {
+    const model = nodeData.inputs?.model as ChatOpenAI
+    const memory = nodeData.inputs?.memory
+    const systemMessage = nodeData.inputs?.systemMessage as string
+    let tools = nodeData.inputs?.tools
+    tools = flatten(tools)
+    const memoryKey = memory.memoryKey ?? 'chat_history'
+
+    const prompt = ChatPromptTemplate.fromMessages([
+        ['ai', systemMessage ?? `You are a helpful AI assistant.`],
+        new MessagesPlaceholder(memoryKey),
+        ['human', '{input}'],
+        new MessagesPlaceholder('agent_scratchpad')
+    ])
+
+    const modelWithFunctions = model.bind({
+        functions: [...tools.map((tool: any) => formatToOpenAIFunction(tool))]
+    })
+
+    const runnableAgent = RunnableSequence.from([
+        {
+            input: (i: { input: string; steps: AgentStep[] }) => i.input,
+            agent_scratchpad: (i: { input: string; steps: AgentStep[] }) => formatAgentSteps(i.steps),
+            [memoryKey]: async (_: { input: string; steps: AgentStep[] }) => {
+                const messages: BaseMessage[] = await memory.getChatMessages(sessionId, true)
+                return messages ?? []
+            }
+        },
+        prompt,
+        modelWithFunctions,
+        new OpenAIFunctionsAgentOutputParser()
+    ])
+
+    const executor = AgentExecutorExtended.fromAgentAndTools({
+        agent: runnableAgent,
+        tools,
+        sessionId
+    })
+
+    return executor
+}
+
+type AgentExecutorOutput = ChainValues
+
+class AgentExecutorExtended extends AgentExecutor {
+    sessionId?: string
+
+    static fromAgentAndTools(fields: AgentExecutorInput & { sessionId?: string }): AgentExecutorExtended {
+        const newInstance = new AgentExecutorExtended(fields)
+        if (fields.sessionId) newInstance.sessionId = fields.sessionId
+        return newInstance
+    }
+
+    shouldContinueIteration(iterations: number): boolean {
+        return this.maxIterations === undefined || iterations < this.maxIterations
+    }
+
+    async _call(inputs: ChainValues, runManager?: CallbackManagerForChainRun): Promise<AgentExecutorOutput> {
+        const toolsByName = Object.fromEntries(this.tools.map((t) => [t.name.toLowerCase(), t]))
+
+        const steps: AgentStep[] = []
+        let iterations = 0
+
+        const getOutput = async (finishStep: AgentFinish): Promise<AgentExecutorOutput> => {
+            const { returnValues } = finishStep
+            const additional = await this.agent.prepareForOutput(returnValues, steps)
+
+            if (this.returnIntermediateSteps) {
+                return { ...returnValues, intermediateSteps: steps, ...additional }
+            }
+            await runManager?.handleAgentEnd(finishStep)
+            return { ...returnValues, ...additional }
+        }
+
+        while (this.shouldContinueIteration(iterations)) {
+            let output
+            try {
+                output = await this.agent.plan(steps, inputs, runManager?.getChild())
+            } catch (e) {
+                if (e instanceof OutputParserException) {
+                    let observation
+                    let text = e.message
+                    if (this.handleParsingErrors === true) {
+                        if (e.sendToLLM) {
+                            observation = e.observation
+                            text = e.llmOutput ?? ''
+                        } else {
+                            observation = 'Invalid or incomplete response'
+                        }
+                    } else if (typeof this.handleParsingErrors === 'string') {
+                        observation = this.handleParsingErrors
+                    } else if (typeof this.handleParsingErrors === 'function') {
+                        observation = this.handleParsingErrors(e)
+                    } else {
+                        throw e
+                    }
+                    output = {
+                        tool: '_Exception',
+                        toolInput: observation,
+                        log: text
+                    } as AgentAction
+                } else {
+                    throw e
+                }
+            }
+            // Check if the agent has finished
+            if ('returnValues' in output) {
+                return getOutput(output)
+            }
+
+            let actions: AgentAction[]
+            if (Array.isArray(output)) {
+                actions = output as AgentAction[]
+            } else {
+                actions = [output as AgentAction]
+            }
+
+            const newSteps = await Promise.all(
+                actions.map(async (action) => {
+                    await runManager?.handleAgentAction(action)
+                    const tool = action.tool === '_Exception' ? new ExceptionTool() : toolsByName[action.tool?.toLowerCase()]
+                    let observation
+                    try {
+                        // here we need to override Tool call method to include sessionId as parameter
+                        observation = tool
+                            ? // @ts-ignore
+                              await tool.call(action.toolInput, runManager?.getChild(), undefined, this.sessionId)
+                            : `${action.tool} is not a valid tool, try another one.`
+                    } catch (e) {
+                        if (e instanceof ToolInputParsingException) {
+                            if (this.handleParsingErrors === true) {
+                                observation = 'Invalid or incomplete tool input. Please try again.'
+                            } else if (typeof this.handleParsingErrors === 'string') {
+                                observation = this.handleParsingErrors
+                            } else if (typeof this.handleParsingErrors === 'function') {
+                                observation = this.handleParsingErrors(e)
+                            } else {
+                                throw e
+                            }
+                            observation = await new ExceptionTool().call(observation, runManager?.getChild())
+                            return { action, observation: observation ?? '' }
+                        }
+                    }
+                    return { action, observation: observation ?? '' }
+                })
+            )
+
+            steps.push(...newSteps)
+
+            const lastStep = steps[steps.length - 1]
+            const lastTool = toolsByName[lastStep.action.tool?.toLowerCase()]
+
+            if (lastTool?.returnDirect) {
+                return getOutput({
+                    returnValues: { [this.agent.returnValues[0]]: lastStep.observation },
+                    log: ''
+                })
+            }
+
+            iterations += 1
+        }
+
+        const finish = await this.agent.returnStoppedResponse(this.earlyStoppingMethod, steps, inputs)
+
+        return getOutput(finish)
+    }
+
+    async _takeNextStep(
+        nameToolMap: Record<string, Tool>,
+        inputs: ChainValues,
+        intermediateSteps: AgentStep[],
+        runManager?: CallbackManagerForChainRun
+    ): Promise<AgentStep[] | AgentFinish> {
+        let output
+        try {
+            output = await this.agent.plan(intermediateSteps, inputs, runManager?.getChild())
+        } catch (e) {
+            if (e instanceof OutputParserException) {
+                let observation
+                let text = e.message
+                if (this.handleParsingErrors === true) {
+                    if (e.sendToLLM) {
+                        observation = e.observation
+                        text = e.llmOutput ?? ''
+                    } else {
+                        observation = 'Invalid or incomplete response'
+                    }
+                } else if (typeof this.handleParsingErrors === 'string') {
+                    observation = this.handleParsingErrors
+                } else if (typeof this.handleParsingErrors === 'function') {
+                    observation = this.handleParsingErrors(e)
+                } else {
+                    throw e
+                }
+                output = {
+                    tool: '_Exception',
+                    toolInput: observation,
+                    log: text
+                } as AgentAction
+            } else {
+                throw e
+            }
+        }
+
+        if ('returnValues' in output) {
+            return output
+        }
+
+        let actions: AgentAction[]
+        if (Array.isArray(output)) {
+            actions = output as AgentAction[]
+        } else {
+            actions = [output as AgentAction]
+        }
+
+        const result: AgentStep[] = []
+        for (const agentAction of actions) {
+            let observation = ''
+            if (runManager) {
+                await runManager?.handleAgentAction(agentAction)
+            }
+            if (agentAction.tool in nameToolMap) {
+                const tool = nameToolMap[agentAction.tool]
+                try {
+                    // here we need to override Tool call method to include sessionId as parameter
+                    // @ts-ignore
+                    observation = await tool.call(agentAction.toolInput, runManager?.getChild(), undefined, this.sessionId)
+                } catch (e) {
+                    if (e instanceof ToolInputParsingException) {
+                        if (this.handleParsingErrors === true) {
+                            observation = 'Invalid or incomplete tool input. Please try again.'
+                        } else if (typeof this.handleParsingErrors === 'string') {
+                            observation = this.handleParsingErrors
+                        } else if (typeof this.handleParsingErrors === 'function') {
+                            observation = this.handleParsingErrors(e)
+                        } else {
+                            throw e
+                        }
+                        observation = await new ExceptionTool().call(observation, runManager?.getChild())
+                    }
+                }
+            } else {
+                observation = `${agentAction.tool} is not a valid tool, try another available tool: ${Object.keys(nameToolMap).join(', ')}`
+            }
+            result.push({
+                action: agentAction,
+                observation
+            })
+        }
+        return result
+    }
+}
+
+class ExceptionTool extends Tool {
+    name = '_Exception'
+
+    description = 'Exception tool'
+
+    async _call(query: string) {
+        return query
     }
 }
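The behavioural core of `AgentExecutorExtended` is that it forwards its `sessionId` as an extra trailing argument to `tool.call(...)` (hence the `@ts-ignore`s above). A minimal stand-alone sketch of that calling contract — class and field names here are illustrative, not taken from the diff:

```ts
// Illustrative only: a tool that opts in to the extra argument AgentExecutorExtended now forwards.
class SessionAwareTool {
    name = 'current_session'
    description = 'Returns the chat session id it was invoked under'

    async call(_arg: unknown, _callbacks?: unknown, _tags?: string[], overrideSessionId?: string): Promise<string> {
        // AgentExecutorExtended passes this.sessionId as the 4th argument of tool.call(...)
        return overrideSessionId ?? 'no-session'
    }
}

// Mirrors the call site inside _call / _takeNextStep above:
const tool = new SessionAwareTool()
tool.call({ input: 'ignored' }, undefined, undefined, 'chat-123').then(console.log) // -> 'chat-123'
```

Tools that keep the stock LangChain `call` signature simply ignore the extra argument, which is why the override is backward compatible.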
diff --git a/packages/components/nodes/memory/RedisBackedChatMemory/RedisBackedChatMemory.ts b/packages/components/nodes/memory/RedisBackedChatMemory/RedisBackedChatMemory.ts
index d6ec9a114..72af1cb5d 100644
--- a/packages/components/nodes/memory/RedisBackedChatMemory/RedisBackedChatMemory.ts
+++ b/packages/components/nodes/memory/RedisBackedChatMemory/RedisBackedChatMemory.ts
@@ -1,8 +1,14 @@
-import { INode, INodeData, INodeParams, ICommonObject } from '../../../src/Interface'
-import { getBaseClasses, getCredentialData, getCredentialParam, serializeChatHistory } from '../../../src/utils'
+import { INode, INodeData, INodeParams, ICommonObject, IMessage, MessageType } from '../../../src/Interface'
+import {
+    convertBaseMessagetoIMessage,
+    getBaseClasses,
+    getCredentialData,
+    getCredentialParam,
+    serializeChatHistory
+} from '../../../src/utils'
 import { BufferMemory, BufferMemoryInput } from 'langchain/memory'
 import { RedisChatMessageHistory, RedisChatMessageHistoryInput } from 'langchain/stores/message/ioredis'
-import { mapStoredMessageToChatMessage, BaseMessage } from 'langchain/schema'
+import { mapStoredMessageToChatMessage, BaseMessage, AIMessage, HumanMessage } from 'langchain/schema'
 import { Redis } from 'ioredis'

 class RedisBackedChatMemory_Memory implements INode {
@@ -94,14 +100,20 @@ class RedisBackedChatMemory_Memory implements INode {
 }

 const initalizeRedis = async (nodeData: INodeData, options: ICommonObject): Promise<BufferMemory> => {
-    const sessionId = nodeData.inputs?.sessionId as string
     const sessionTTL = nodeData.inputs?.sessionTTL as number
     const memoryKey = nodeData.inputs?.memoryKey as string
     const windowSize = nodeData.inputs?.windowSize as number
     const chatId = options?.chatId as string

     let isSessionIdUsingChatMessageId = false
-    if (!sessionId && chatId) isSessionIdUsingChatMessageId = true
+    let sessionId = ''
+
+    if (!nodeData.inputs?.sessionId && chatId) {
+        isSessionIdUsingChatMessageId = true
+        sessionId = chatId
+    } else {
+        sessionId = nodeData.inputs?.sessionId
+    }

     const credentialData = await getCredentialData(nodeData.credential ?? '', options)
     const redisUrl = getCredentialParam('redisUrl', credentialData, nodeData)
@@ -128,7 +140,7 @@ const initalizeRedis = async (nodeData: INodeData, options: ICommonObject): Prom
     }

     let obj: RedisChatMessageHistoryInput = {
-        sessionId: sessionId ? sessionId : chatId,
+        sessionId,
         client
     }

@@ -162,21 +174,67 @@ const initalizeRedis = async (nodeData: INodeData, options: ICommonObject): Prom
     const memory = new BufferMemoryExtended({
         memoryKey: memoryKey ?? 'chat_history',
         chatHistory: redisChatMessageHistory,
-        isSessionIdUsingChatMessageId
+        isSessionIdUsingChatMessageId,
+        sessionId,
+        redisClient: client
     })

     return memory
 }

 interface BufferMemoryExtendedInput {
     isSessionIdUsingChatMessageId: boolean
+    redisClient: Redis
+    sessionId: string
 }

 class BufferMemoryExtended extends BufferMemory {
     isSessionIdUsingChatMessageId? = false
+    sessionId = ''
+    redisClient: Redis

-    constructor(fields: BufferMemoryInput & Partial<BufferMemoryExtendedInput>) {
+    constructor(fields: BufferMemoryInput & BufferMemoryExtendedInput) {
         super(fields)
         this.isSessionIdUsingChatMessageId = fields.isSessionIdUsingChatMessageId
+        this.sessionId = fields.sessionId
+        this.redisClient = fields.redisClient
+    }
+
+    async getChatMessages(overrideSessionId = '', returnBaseMessage = false): Promise<IMessage[] | BaseMessage[]> {
+        if (!this.redisClient) return []
+
+        const id = overrideSessionId ?? this.sessionId
+        const rawStoredMessages = await this.redisClient.lrange(id, 0, -1)
+        const orderedMessages = rawStoredMessages.reverse().map((message) => JSON.parse(message))
+        const baseMessages = orderedMessages.map(mapStoredMessageToChatMessage)
+        return returnBaseMessage ? baseMessages : convertBaseMessagetoIMessage(baseMessages)
+    }
+
+    async addChatMessages(msgArray: { text: string; type: MessageType }[], overrideSessionId = ''): Promise<void> {
+        if (!this.redisClient) return
+
+        const id = overrideSessionId ?? this.sessionId
+        const input = msgArray.find((msg) => msg.type === 'userMessage')
+        const output = msgArray.find((msg) => msg.type === 'apiMessage')
+
+        if (input) {
+            const newInputMessage = new HumanMessage(input.text)
+            const messageToAdd = [newInputMessage].map((msg) => msg.toDict())
+            await this.redisClient.lpush(id, JSON.stringify(messageToAdd[0]))
+        }
+
+        if (output) {
+            const newOutputMessage = new AIMessage(output.text)
+            const messageToAdd = [newOutputMessage].map((msg) => msg.toDict())
+            await this.redisClient.lpush(id, JSON.stringify(messageToAdd[0]))
+        }
+    }
+
+    async clearChatMessages(overrideSessionId = ''): Promise<void> {
+        if (!this.redisClient) return
+
+        const id = overrideSessionId ?? this.sessionId
+        await this.redisClient.del(id)
+        await this.clear()
     }
 }
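The new `BufferMemoryExtended` methods rely on Redis list semantics: messages are written with `LPUSH` (newest first), so `LRANGE key 0 -1` has to be reversed to recover chronological order before mapping back to chat messages. A rough round-trip sketch, assuming a local Redis on the default port; the stored shape is a simplification of what `BaseMessage.toDict()` produces in the real `addChatMessages`:

```ts
import { Redis } from 'ioredis'

// Assumes a Redis server on localhost:6379; key and payload shapes are illustrative.
const client = new Redis()
const sessionId = 'demo-session'

async function roundTrip() {
    await client.del(sessionId)

    // addChatMessages: each message is LPUSHed, so the newest entry ends up at index 0.
    await client.lpush(sessionId, JSON.stringify({ type: 'human', data: { content: 'hi' } }))
    await client.lpush(sessionId, JSON.stringify({ type: 'ai', data: { content: 'hello!' } }))

    // getChatMessages: LRANGE 0 -1 returns newest-to-oldest, so it is reversed before mapping.
    const raw = await client.lrange(sessionId, 0, -1)
    const ordered = raw.reverse().map((m) => JSON.parse(m))
    console.log(ordered.map((m) => m.data.content)) // -> [ 'hi', 'hello!' ]

    await client.quit()
}

roundTrip().catch(console.error)
```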
diff --git a/packages/components/nodes/tools/CustomTool/core.ts b/packages/components/nodes/tools/CustomTool/core.ts
index 7e76e37f8..77aa0e6b4 100644
--- a/packages/components/nodes/tools/CustomTool/core.ts
+++ b/packages/components/nodes/tools/CustomTool/core.ts
@@ -1,8 +1,18 @@
 import { z } from 'zod'
-import { CallbackManagerForToolRun } from 'langchain/callbacks'
-import { StructuredTool, ToolParams } from 'langchain/tools'
 import { NodeVM } from 'vm2'
 import { availableDependencies } from '../../../src/utils'
+import { RunnableConfig } from '@langchain/core/runnables'
+import { StructuredTool, ToolParams } from '@langchain/core/tools'
+import { CallbackManagerForToolRun, Callbacks, CallbackManager, parseCallbackConfigArg } from '@langchain/core/callbacks/manager'
+
+class ToolInputParsingException extends Error {
+    output?: string
+
+    constructor(message: string, output?: string) {
+        super(message)
+        this.output = output
+    }
+}

 export interface BaseDynamicToolInput extends ToolParams {
     name: string
@@ -45,7 +55,47 @@ export class DynamicStructuredTool<
         this.schema = fields.schema
     }

-    protected async _call(arg: z.output<T>): Promise<string> {
+    async call(arg: z.output<T>, configArg?: RunnableConfig | Callbacks, tags?: string[], overrideSessionId?: string): Promise<string> {
+        const config = parseCallbackConfigArg(configArg)
+        if (config.runName === undefined) {
+            config.runName = this.name
+        }
+        let parsed
+        try {
+            parsed = await this.schema.parseAsync(arg)
+        } catch (e) {
+            throw new ToolInputParsingException(`Received tool input did not match expected schema`, JSON.stringify(arg))
+        }
+        const callbackManager_ = await CallbackManager.configure(
+            config.callbacks,
+            this.callbacks,
+            config.tags || tags,
+            this.tags,
+            config.metadata,
+            this.metadata,
+            { verbose: this.verbose }
+        )
+        const runManager = await callbackManager_?.handleToolStart(
+            this.toJSON(),
+            typeof parsed === 'string' ? parsed : JSON.stringify(parsed),
+            undefined,
+            undefined,
+            undefined,
+            undefined,
+            config.runName
+        )
+        let result
+        try {
+            result = await this._call(parsed, runManager, overrideSessionId)
+        } catch (e) {
+            await runManager?.handleToolError(e)
+            throw e
+        }
+        await runManager?.handleToolEnd(result)
+        return result
+    }
+
+    protected async _call(arg: z.output<T>, _?: CallbackManagerForToolRun, overrideSessionId?: string): Promise<string> {
         let sandbox: any = {}
         if (typeof arg === 'object' && Object.keys(arg).length) {
             for (const item in arg) {
@@ -70,7 +120,7 @@ export class DynamicStructuredTool<
         }
         sandbox['$env'] = env
         if (this.flowObj) {
-            sandbox['$flow'] = this.flowObj
+            sandbox['$flow'] = { ...this.flowObj, sessionId: overrideSessionId }
         }
         const defaultAllowBuiltInDep = [
             'assert',
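End to end, the session id handed to `DynamicStructuredTool.call(...)` surfaces inside the vm2 sandbox as `$flow.sessionId`. A small sketch of that merge, outside the class for clarity — the other `flowObj` field names are assumed for illustration; only `sessionId` is introduced by this diff:

```ts
// Sketch of the $flow merge performed in the overridden _call.
const flowObj = { chatflowid: 'flow-1', chatId: 'chat-1', input: 'user question' } // assumed shape
const overrideSessionId = 'chat-123' // forwarded by AgentExecutorExtended through call(..., sessionId)

const sandbox: Record<string, unknown> = {}
sandbox['$flow'] = { ...flowObj, sessionId: overrideSessionId }

// Inside the vm2 NodeVM, a Custom Tool body can then read the session it runs under,
// e.g. `const sessionId = $flow.sessionId`.
console.log((sandbox['$flow'] as { sessionId?: string }).sessionId) // -> 'chat-123'
```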
diff --git a/packages/components/src/utils.ts b/packages/components/src/utils.ts
index 239b13ca8..2f6855d70 100644
--- a/packages/components/src/utils.ts
+++ b/packages/components/src/utils.ts
@@ -8,7 +8,7 @@ import { DataSource } from 'typeorm'
 import { ICommonObject, IDatabaseEntity, IMessage, INodeData } from './Interface'
 import { AES, enc } from 'crypto-js'
 import { ChatMessageHistory } from 'langchain/memory'
-import { AIMessage, HumanMessage } from 'langchain/schema'
+import { AIMessage, HumanMessage, BaseMessage } from 'langchain/schema'

 export const numberOrExpressionRegex = '^(\\d+\\.?\\d*|{{.*}})$' //return true if string consists only numbers OR expression {{}}
 export const notEmptyRegex = '(.|\\s)*\\S(.|\\s)*' //return true if string is not empty or blank
@@ -644,3 +644,31 @@ export const convertSchemaToZod = (schema: string | object): ICommonObject => {
         throw new Error(e)
     }
 }
+
+/**
+ * Convert BaseMessage to IMessage
+ * @param {BaseMessage[]} messages
+ * @returns {IMessage[]}
+ */
+export const convertBaseMessagetoIMessage = (messages: BaseMessage[]): IMessage[] => {
+    const formatmessages: IMessage[] = []
+    for (const m of messages) {
+        if (m._getType() === 'human') {
+            formatmessages.push({
+                message: m.content as string,
+                type: 'userMessage'
+            })
+        } else if (m._getType() === 'ai') {
+            formatmessages.push({
+                message: m.content as string,
+                type: 'apiMessage'
+            })
+        } else if (m._getType() === 'system') {
+            formatmessages.push({
+                message: m.content as string,
+                type: 'apiMessage'
+            })
+        }
+    }
+    return formatmessages
+}
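For reference, a hedged usage sketch of the new helper — this is how `getChatMessages` above consumes it when `returnBaseMessage` is false (the import path is assumed; adjust to the consuming package):

```ts
import { AIMessage, HumanMessage } from 'langchain/schema'
import { convertBaseMessagetoIMessage } from './utils' // path assumed

const history = [new HumanMessage('What is Flowise?'), new AIMessage('A low-code LLM app builder.')]

// -> [ { message: 'What is Flowise?', type: 'userMessage' },
//      { message: 'A low-code LLM app builder.', type: 'apiMessage' } ]
console.log(convertBaseMessagetoIMessage(history))
```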
diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts
index 99fcba70e..531533c1b 100644
--- a/packages/server/src/index.ts
+++ b/packages/server/src/index.ts
@@ -44,8 +44,7 @@ import {
     checkMemorySessionId,
     clearSessionMemoryFromViewMessageDialog,
     getUserHome,
-    replaceChatHistory,
-    replaceEnvVariables
+    replaceChatHistory
 } from './utils'
 import { cloneDeep, omit, uniqWith, isEqual } from 'lodash'
 import { getDataSource } from './DataSource'
@@ -1617,10 +1616,6 @@ export class App {
                     this.chatflowPool.add(chatflowid, nodeToExecuteData, startingNodes, incomingInput?.overrideConfig)
                 }

-                const nodeInstanceFilePath = this.nodesPool.componentNodes[nodeToExecuteData.name].filePath as string
-                const nodeModule = await import(nodeInstanceFilePath)
-                const nodeInstance = new nodeModule.nodeClass()
-
                 logger.debug(`[server]: Running ${nodeToExecuteData.label} (${nodeToExecuteData.id})`)

                 let sessionId = undefined
@@ -1634,6 +1629,10 @@ export class App {
                     chatHistory = await replaceChatHistory(memoryNode, incomingInput, this.AppDataSource, databaseEntities, logger)
                 }

+                const nodeInstanceFilePath = this.nodesPool.componentNodes[nodeToExecuteData.name].filePath as string
+                const nodeModule = await import(nodeInstanceFilePath)
+                const nodeInstance = new nodeModule.nodeClass({ sessionId })
+
                 let result = isStreamValid
                     ? await nodeInstance.run(nodeToExecuteData, incomingInput.question, {
                           chatflowid,