Passing state to tool so that we can use them in custom tools (#3103)
This commit is contained in:
parent
7a5246d28a
commit
2e45851822
|
|
@ -1,5 +1,14 @@
|
|||
import { flatten } from 'lodash'
|
||||
import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeParams, ISeqAgentNode, IUsedTool } from '../../../src/Interface'
|
||||
import {
|
||||
ICommonObject,
|
||||
IDatabaseEntity,
|
||||
INode,
|
||||
INodeData,
|
||||
INodeParams,
|
||||
ISeqAgentNode,
|
||||
IUsedTool,
|
||||
IStateWithMessages
|
||||
} from '../../../src/Interface'
|
||||
import { AIMessage, AIMessageChunk, BaseMessage, ToolMessage } from '@langchain/core/messages'
|
||||
import { StructuredTool } from '@langchain/core/tools'
|
||||
import { RunnableConfig } from '@langchain/core/runnables'
|
||||
|
|
@ -9,6 +18,7 @@ import { DataSource } from 'typeorm'
|
|||
import { MessagesState, RunnableCallable, customGet, getVM } from '../commonUtils'
|
||||
import { getVars, prepareSandboxVars } from '../../../src/utils'
|
||||
import { ChatPromptTemplate } from '@langchain/core/prompts'
|
||||
import { DynamicStructuredTool } from '../../tools/CustomTool/core'
|
||||
|
||||
const defaultApprovalPrompt = `You are about to execute tool: {tools}. Ask if user want to proceed`
|
||||
|
||||
|
|
@ -350,7 +360,7 @@ class ToolNode_SeqAgents implements INode {
|
|||
}
|
||||
}
|
||||
|
||||
class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable<T, T> {
|
||||
class ToolNode<T extends IStateWithMessages | BaseMessage[] | MessagesState> extends RunnableCallable<T, BaseMessage[] | MessagesState> {
|
||||
tools: StructuredTool[]
|
||||
nodeData: INodeData
|
||||
inputQuery: string
|
||||
|
|
@ -372,19 +382,45 @@ class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable
|
|||
this.options = options
|
||||
}
|
||||
|
||||
private async run(input: BaseMessage[] | MessagesState, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> {
|
||||
const message = Array.isArray(input) ? input[input.length - 1] : input.messages[input.messages.length - 1]
|
||||
private async run(input: T, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> {
|
||||
let messages: BaseMessage[]
|
||||
|
||||
// Check if input is an array of BaseMessage[]
|
||||
if (Array.isArray(input)) {
|
||||
messages = input
|
||||
}
|
||||
// Check if input is IStateWithMessages
|
||||
else if ((input as IStateWithMessages).messages) {
|
||||
messages = (input as IStateWithMessages).messages
|
||||
}
|
||||
// Handle MessagesState type
|
||||
else {
|
||||
messages = (input as MessagesState).messages
|
||||
}
|
||||
|
||||
// Get the last message
|
||||
const message = messages[messages.length - 1]
|
||||
|
||||
if (message._getType() !== 'ai') {
|
||||
throw new Error('ToolNode only accepts AIMessages as input.')
|
||||
}
|
||||
|
||||
// Extract all properties except messages for IStateWithMessages
|
||||
const { messages: _, ...inputWithoutMessages } = Array.isArray(input) ? { messages: input } : input
|
||||
const ChannelsWithoutMessages = {
|
||||
state: inputWithoutMessages
|
||||
}
|
||||
|
||||
const outputs = await Promise.all(
|
||||
(message as AIMessage).tool_calls?.map(async (call) => {
|
||||
const tool = this.tools.find((tool) => tool.name === call.name)
|
||||
if (tool === undefined) {
|
||||
throw new Error(`Tool ${call.name} not found.`)
|
||||
}
|
||||
if (tool && tool instanceof DynamicStructuredTool) {
|
||||
// @ts-ignore
|
||||
tool.setFlowObject(ChannelsWithoutMessages)
|
||||
}
|
||||
let output = await tool.invoke(call.args, config)
|
||||
let sourceDocuments: Document[] = []
|
||||
if (output?.includes(SOURCE_DOCUMENTS_PREFIX)) {
|
||||
|
|
@ -436,7 +472,7 @@ const getReturnOutput = async (
|
|||
input: string,
|
||||
options: ICommonObject,
|
||||
outputs: ToolMessage[],
|
||||
state: BaseMessage[] | MessagesState
|
||||
state: ICommonObject
|
||||
) => {
|
||||
const appDataSource = options.appDataSource as DataSource
|
||||
const databaseEntities = options.databaseEntities as IDatabaseEntity
|
||||
|
|
|
|||
|
|
@ -396,3 +396,7 @@ export interface IVisionChatModal {
|
|||
revertToOriginalModel(): void
|
||||
setMultiModalOption(multiModalOption: IMultiModalOption): void
|
||||
}
|
||||
export interface IStateWithMessages extends ICommonObject {
|
||||
messages: BaseMessage[]
|
||||
[key: string]: any
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue