Bugfix/Pass state to tool node for agents (#3139)

pass state to tool node for agents
This commit is contained in:
Henry Heng 2024-09-03 22:26:37 +01:00 committed by GitHub
parent 2a21f18bf8
commit e6918381a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 30 additions and 2 deletions

View File

@ -18,7 +18,8 @@ import {
ISeqAgentNode, ISeqAgentNode,
IDatabaseEntity, IDatabaseEntity,
IUsedTool, IUsedTool,
IDocument IDocument,
IStateWithMessages
} from '../../../src/Interface' } from '../../../src/Interface'
import { ToolCallingAgentOutputParser, AgentExecutor, SOURCE_DOCUMENTS_PREFIX } from '../../../src/agents' import { ToolCallingAgentOutputParser, AgentExecutor, SOURCE_DOCUMENTS_PREFIX } from '../../../src/agents'
import { getInputVariables, getVars, handleEscapeCharacters, prepareSandboxVars } from '../../../src/utils' import { getInputVariables, getVars, handleEscapeCharacters, prepareSandboxVars } from '../../../src/utils'
@ -34,6 +35,7 @@ import {
} from '../commonUtils' } from '../commonUtils'
import { END, StateGraph } from '@langchain/langgraph' import { END, StateGraph } from '@langchain/langgraph'
import { StructuredTool } from '@langchain/core/tools' import { StructuredTool } from '@langchain/core/tools'
import { DynamicStructuredTool } from '../../tools/CustomTool/core'
const defaultApprovalPrompt = `You are about to execute tool: {tools}. Ask if user want to proceed` const defaultApprovalPrompt = `You are about to execute tool: {tools}. Ask if user want to proceed`
const examplePrompt = 'You are a research assistant who can search for up-to-date info using search engine.' const examplePrompt = 'You are a research assistant who can search for up-to-date info using search engine.'
@ -904,18 +906,44 @@ class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable
} }
private async run(input: BaseMessage[] | MessagesState, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> { private async run(input: BaseMessage[] | MessagesState, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> {
const message = Array.isArray(input) ? input[input.length - 1] : input.messages[input.messages.length - 1] let messages: BaseMessage[]
// Check if input is an array of BaseMessage[]
if (Array.isArray(input)) {
messages = input
}
// Check if input is IStateWithMessages
else if ((input as IStateWithMessages).messages) {
messages = (input as IStateWithMessages).messages
}
// Handle MessagesState type
else {
messages = (input as MessagesState).messages
}
// Get the last message
const message = messages[messages.length - 1]
if (message._getType() !== 'ai') { if (message._getType() !== 'ai') {
throw new Error('ToolNode only accepts AIMessages as input.') throw new Error('ToolNode only accepts AIMessages as input.')
} }
// Extract all properties except messages for IStateWithMessages
const { messages: _, ...inputWithoutMessages } = Array.isArray(input) ? { messages: input } : input
const ChannelsWithoutMessages = {
state: inputWithoutMessages
}
const outputs = await Promise.all( const outputs = await Promise.all(
(message as AIMessage).tool_calls?.map(async (call) => { (message as AIMessage).tool_calls?.map(async (call) => {
const tool = this.tools.find((tool) => tool.name === call.name) const tool = this.tools.find((tool) => tool.name === call.name)
if (tool === undefined) { if (tool === undefined) {
throw new Error(`Tool ${call.name} not found.`) throw new Error(`Tool ${call.name} not found.`)
} }
if (tool && tool instanceof DynamicStructuredTool) {
// @ts-ignore
tool.setFlowObject(ChannelsWithoutMessages)
}
let output = await tool.invoke(call.args, config) let output = await tool.invoke(call.args, config)
let sourceDocuments: Document[] = [] let sourceDocuments: Document[] = []
if (output?.includes(SOURCE_DOCUMENTS_PREFIX)) { if (output?.includes(SOURCE_DOCUMENTS_PREFIX)) {