Bugfix/Pass state to tool node for agents (#3139)
pass state to tool node for agents
This commit is contained in:
parent
2a21f18bf8
commit
e6918381a5
|
|
@ -18,7 +18,8 @@ import {
|
||||||
ISeqAgentNode,
|
ISeqAgentNode,
|
||||||
IDatabaseEntity,
|
IDatabaseEntity,
|
||||||
IUsedTool,
|
IUsedTool,
|
||||||
IDocument
|
IDocument,
|
||||||
|
IStateWithMessages
|
||||||
} from '../../../src/Interface'
|
} from '../../../src/Interface'
|
||||||
import { ToolCallingAgentOutputParser, AgentExecutor, SOURCE_DOCUMENTS_PREFIX } from '../../../src/agents'
|
import { ToolCallingAgentOutputParser, AgentExecutor, SOURCE_DOCUMENTS_PREFIX } from '../../../src/agents'
|
||||||
import { getInputVariables, getVars, handleEscapeCharacters, prepareSandboxVars } from '../../../src/utils'
|
import { getInputVariables, getVars, handleEscapeCharacters, prepareSandboxVars } from '../../../src/utils'
|
||||||
|
|
@ -34,6 +35,7 @@ import {
|
||||||
} from '../commonUtils'
|
} from '../commonUtils'
|
||||||
import { END, StateGraph } from '@langchain/langgraph'
|
import { END, StateGraph } from '@langchain/langgraph'
|
||||||
import { StructuredTool } from '@langchain/core/tools'
|
import { StructuredTool } from '@langchain/core/tools'
|
||||||
|
import { DynamicStructuredTool } from '../../tools/CustomTool/core'
|
||||||
|
|
||||||
const defaultApprovalPrompt = `You are about to execute tool: {tools}. Ask if user want to proceed`
|
const defaultApprovalPrompt = `You are about to execute tool: {tools}. Ask if user want to proceed`
|
||||||
const examplePrompt = 'You are a research assistant who can search for up-to-date info using search engine.'
|
const examplePrompt = 'You are a research assistant who can search for up-to-date info using search engine.'
|
||||||
|
|
@ -904,18 +906,44 @@ class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable
|
||||||
}
|
}
|
||||||
|
|
||||||
private async run(input: BaseMessage[] | MessagesState, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> {
|
private async run(input: BaseMessage[] | MessagesState, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> {
|
||||||
const message = Array.isArray(input) ? input[input.length - 1] : input.messages[input.messages.length - 1]
|
let messages: BaseMessage[]
|
||||||
|
|
||||||
|
// Check if input is an array of BaseMessage[]
|
||||||
|
if (Array.isArray(input)) {
|
||||||
|
messages = input
|
||||||
|
}
|
||||||
|
// Check if input is IStateWithMessages
|
||||||
|
else if ((input as IStateWithMessages).messages) {
|
||||||
|
messages = (input as IStateWithMessages).messages
|
||||||
|
}
|
||||||
|
// Handle MessagesState type
|
||||||
|
else {
|
||||||
|
messages = (input as MessagesState).messages
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the last message
|
||||||
|
const message = messages[messages.length - 1]
|
||||||
|
|
||||||
if (message._getType() !== 'ai') {
|
if (message._getType() !== 'ai') {
|
||||||
throw new Error('ToolNode only accepts AIMessages as input.')
|
throw new Error('ToolNode only accepts AIMessages as input.')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract all properties except messages for IStateWithMessages
|
||||||
|
const { messages: _, ...inputWithoutMessages } = Array.isArray(input) ? { messages: input } : input
|
||||||
|
const ChannelsWithoutMessages = {
|
||||||
|
state: inputWithoutMessages
|
||||||
|
}
|
||||||
|
|
||||||
const outputs = await Promise.all(
|
const outputs = await Promise.all(
|
||||||
(message as AIMessage).tool_calls?.map(async (call) => {
|
(message as AIMessage).tool_calls?.map(async (call) => {
|
||||||
const tool = this.tools.find((tool) => tool.name === call.name)
|
const tool = this.tools.find((tool) => tool.name === call.name)
|
||||||
if (tool === undefined) {
|
if (tool === undefined) {
|
||||||
throw new Error(`Tool ${call.name} not found.`)
|
throw new Error(`Tool ${call.name} not found.`)
|
||||||
}
|
}
|
||||||
|
if (tool && tool instanceof DynamicStructuredTool) {
|
||||||
|
// @ts-ignore
|
||||||
|
tool.setFlowObject(ChannelsWithoutMessages)
|
||||||
|
}
|
||||||
let output = await tool.invoke(call.args, config)
|
let output = await tool.invoke(call.args, config)
|
||||||
let sourceDocuments: Document[] = []
|
let sourceDocuments: Document[] = []
|
||||||
if (output?.includes(SOURCE_DOCUMENTS_PREFIX)) {
|
if (output?.includes(SOURCE_DOCUMENTS_PREFIX)) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue