Passing state to tool so that we can use them in custom tools (#3103)
This commit is contained in:
parent
7a5246d28a
commit
2e45851822
|
|
@ -1,5 +1,14 @@
|
||||||
import { flatten } from 'lodash'
|
import { flatten } from 'lodash'
|
||||||
import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeParams, ISeqAgentNode, IUsedTool } from '../../../src/Interface'
|
import {
|
||||||
|
ICommonObject,
|
||||||
|
IDatabaseEntity,
|
||||||
|
INode,
|
||||||
|
INodeData,
|
||||||
|
INodeParams,
|
||||||
|
ISeqAgentNode,
|
||||||
|
IUsedTool,
|
||||||
|
IStateWithMessages
|
||||||
|
} from '../../../src/Interface'
|
||||||
import { AIMessage, AIMessageChunk, BaseMessage, ToolMessage } from '@langchain/core/messages'
|
import { AIMessage, AIMessageChunk, BaseMessage, ToolMessage } from '@langchain/core/messages'
|
||||||
import { StructuredTool } from '@langchain/core/tools'
|
import { StructuredTool } from '@langchain/core/tools'
|
||||||
import { RunnableConfig } from '@langchain/core/runnables'
|
import { RunnableConfig } from '@langchain/core/runnables'
|
||||||
|
|
@ -9,6 +18,7 @@ import { DataSource } from 'typeorm'
|
||||||
import { MessagesState, RunnableCallable, customGet, getVM } from '../commonUtils'
|
import { MessagesState, RunnableCallable, customGet, getVM } from '../commonUtils'
|
||||||
import { getVars, prepareSandboxVars } from '../../../src/utils'
|
import { getVars, prepareSandboxVars } from '../../../src/utils'
|
||||||
import { ChatPromptTemplate } from '@langchain/core/prompts'
|
import { ChatPromptTemplate } from '@langchain/core/prompts'
|
||||||
|
import { DynamicStructuredTool } from '../../tools/CustomTool/core'
|
||||||
|
|
||||||
const defaultApprovalPrompt = `You are about to execute tool: {tools}. Ask if user want to proceed`
|
const defaultApprovalPrompt = `You are about to execute tool: {tools}. Ask if user want to proceed`
|
||||||
|
|
||||||
|
|
@ -350,7 +360,7 @@ class ToolNode_SeqAgents implements INode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable<T, T> {
|
class ToolNode<T extends IStateWithMessages | BaseMessage[] | MessagesState> extends RunnableCallable<T, BaseMessage[] | MessagesState> {
|
||||||
tools: StructuredTool[]
|
tools: StructuredTool[]
|
||||||
nodeData: INodeData
|
nodeData: INodeData
|
||||||
inputQuery: string
|
inputQuery: string
|
||||||
|
|
@ -372,19 +382,45 @@ class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable
|
||||||
this.options = options
|
this.options = options
|
||||||
}
|
}
|
||||||
|
|
||||||
private async run(input: BaseMessage[] | MessagesState, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> {
|
private async run(input: T, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> {
|
||||||
const message = Array.isArray(input) ? input[input.length - 1] : input.messages[input.messages.length - 1]
|
let messages: BaseMessage[]
|
||||||
|
|
||||||
|
// Check if input is an array of BaseMessage[]
|
||||||
|
if (Array.isArray(input)) {
|
||||||
|
messages = input
|
||||||
|
}
|
||||||
|
// Check if input is IStateWithMessages
|
||||||
|
else if ((input as IStateWithMessages).messages) {
|
||||||
|
messages = (input as IStateWithMessages).messages
|
||||||
|
}
|
||||||
|
// Handle MessagesState type
|
||||||
|
else {
|
||||||
|
messages = (input as MessagesState).messages
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the last message
|
||||||
|
const message = messages[messages.length - 1]
|
||||||
|
|
||||||
if (message._getType() !== 'ai') {
|
if (message._getType() !== 'ai') {
|
||||||
throw new Error('ToolNode only accepts AIMessages as input.')
|
throw new Error('ToolNode only accepts AIMessages as input.')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract all properties except messages for IStateWithMessages
|
||||||
|
const { messages: _, ...inputWithoutMessages } = Array.isArray(input) ? { messages: input } : input
|
||||||
|
const ChannelsWithoutMessages = {
|
||||||
|
state: inputWithoutMessages
|
||||||
|
}
|
||||||
|
|
||||||
const outputs = await Promise.all(
|
const outputs = await Promise.all(
|
||||||
(message as AIMessage).tool_calls?.map(async (call) => {
|
(message as AIMessage).tool_calls?.map(async (call) => {
|
||||||
const tool = this.tools.find((tool) => tool.name === call.name)
|
const tool = this.tools.find((tool) => tool.name === call.name)
|
||||||
if (tool === undefined) {
|
if (tool === undefined) {
|
||||||
throw new Error(`Tool ${call.name} not found.`)
|
throw new Error(`Tool ${call.name} not found.`)
|
||||||
}
|
}
|
||||||
|
if (tool && tool instanceof DynamicStructuredTool) {
|
||||||
|
// @ts-ignore
|
||||||
|
tool.setFlowObject(ChannelsWithoutMessages)
|
||||||
|
}
|
||||||
let output = await tool.invoke(call.args, config)
|
let output = await tool.invoke(call.args, config)
|
||||||
let sourceDocuments: Document[] = []
|
let sourceDocuments: Document[] = []
|
||||||
if (output?.includes(SOURCE_DOCUMENTS_PREFIX)) {
|
if (output?.includes(SOURCE_DOCUMENTS_PREFIX)) {
|
||||||
|
|
@ -436,7 +472,7 @@ const getReturnOutput = async (
|
||||||
input: string,
|
input: string,
|
||||||
options: ICommonObject,
|
options: ICommonObject,
|
||||||
outputs: ToolMessage[],
|
outputs: ToolMessage[],
|
||||||
state: BaseMessage[] | MessagesState
|
state: ICommonObject
|
||||||
) => {
|
) => {
|
||||||
const appDataSource = options.appDataSource as DataSource
|
const appDataSource = options.appDataSource as DataSource
|
||||||
const databaseEntities = options.databaseEntities as IDatabaseEntity
|
const databaseEntities = options.databaseEntities as IDatabaseEntity
|
||||||
|
|
|
||||||
|
|
@ -396,3 +396,7 @@ export interface IVisionChatModal {
|
||||||
revertToOriginalModel(): void
|
revertToOriginalModel(): void
|
||||||
setMultiModalOption(multiModalOption: IMultiModalOption): void
|
setMultiModalOption(multiModalOption: IMultiModalOption): void
|
||||||
}
|
}
|
||||||
|
export interface IStateWithMessages extends ICommonObject {
|
||||||
|
messages: BaseMessage[]
|
||||||
|
[key: string]: any
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue