Passing state to tool so that we can use them in custom tools (#3103)

This commit is contained in:
Jrakru 2024-08-30 15:50:16 -04:00 committed by GitHub
parent 7a5246d28a
commit 2e45851822
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 5 deletions

View File

@ -1,5 +1,14 @@
import { flatten } from 'lodash'
import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeParams, ISeqAgentNode, IUsedTool } from '../../../src/Interface'
import {
ICommonObject,
IDatabaseEntity,
INode,
INodeData,
INodeParams,
ISeqAgentNode,
IUsedTool,
IStateWithMessages
} from '../../../src/Interface'
import { AIMessage, AIMessageChunk, BaseMessage, ToolMessage } from '@langchain/core/messages'
import { StructuredTool } from '@langchain/core/tools'
import { RunnableConfig } from '@langchain/core/runnables'
@ -9,6 +18,7 @@ import { DataSource } from 'typeorm'
import { MessagesState, RunnableCallable, customGet, getVM } from '../commonUtils'
import { getVars, prepareSandboxVars } from '../../../src/utils'
import { ChatPromptTemplate } from '@langchain/core/prompts'
import { DynamicStructuredTool } from '../../tools/CustomTool/core'
const defaultApprovalPrompt = `You are about to execute tool: {tools}. Ask if user want to proceed`
@ -350,7 +360,7 @@ class ToolNode_SeqAgents implements INode {
}
}
class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable<T, T> {
class ToolNode<T extends IStateWithMessages | BaseMessage[] | MessagesState> extends RunnableCallable<T, BaseMessage[] | MessagesState> {
tools: StructuredTool[]
nodeData: INodeData
inputQuery: string
@ -372,19 +382,45 @@ class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable
this.options = options
}
private async run(input: BaseMessage[] | MessagesState, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> {
const message = Array.isArray(input) ? input[input.length - 1] : input.messages[input.messages.length - 1]
private async run(input: T, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> {
let messages: BaseMessage[]
// Check if input is an array of BaseMessage[]
if (Array.isArray(input)) {
messages = input
}
// Check if input is IStateWithMessages
else if ((input as IStateWithMessages).messages) {
messages = (input as IStateWithMessages).messages
}
// Handle MessagesState type
else {
messages = (input as MessagesState).messages
}
// Get the last message
const message = messages[messages.length - 1]
if (message._getType() !== 'ai') {
throw new Error('ToolNode only accepts AIMessages as input.')
}
// Extract all properties except messages for IStateWithMessages
const { messages: _, ...inputWithoutMessages } = Array.isArray(input) ? { messages: input } : input
const ChannelsWithoutMessages = {
state: inputWithoutMessages
}
const outputs = await Promise.all(
(message as AIMessage).tool_calls?.map(async (call) => {
const tool = this.tools.find((tool) => tool.name === call.name)
if (tool === undefined) {
throw new Error(`Tool ${call.name} not found.`)
}
if (tool && tool instanceof DynamicStructuredTool) {
// @ts-ignore
tool.setFlowObject(ChannelsWithoutMessages)
}
let output = await tool.invoke(call.args, config)
let sourceDocuments: Document[] = []
if (output?.includes(SOURCE_DOCUMENTS_PREFIX)) {
@ -436,7 +472,7 @@ const getReturnOutput = async (
input: string,
options: ICommonObject,
outputs: ToolMessage[],
state: BaseMessage[] | MessagesState
state: ICommonObject
) => {
const appDataSource = options.appDataSource as DataSource
const databaseEntities = options.databaseEntities as IDatabaseEntity

View File

@ -396,3 +396,7 @@ export interface IVisionChatModal {
revertToOriginalModel(): void
setMultiModalOption(multiModalOption: IMultiModalOption): void
}
export interface IStateWithMessages extends ICommonObject {
messages: BaseMessage[]
[key: string]: any
}