Bugfix/stream custom tool return direct (#3003)

stream custom tool return direct
This commit is contained in:
Henry Heng 2024-08-12 18:35:15 +01:00 committed by GitHub
parent f57dc2477f
commit b9f0ec3a3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 44 additions and 1 deletions

View File

@ -136,6 +136,17 @@ class ConversationalAgent_Agents implements INode {
options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools) options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
usedTools = res.usedTools usedTools = res.usedTools
} }
// If the tool is set to returnDirect, stream the output to the client
if (res.usedTools && res.usedTools.length) {
let inputTools = nodeData.inputs?.tools
inputTools = flatten(inputTools)
for (const tool of res.usedTools) {
const inputTool = inputTools.find((inputTool: Tool) => inputTool.name === tool.tool)
if (inputTool && inputTool.returnDirect) {
options.socketIO.to(options.socketIOClientId).emit('token', tool.toolOutput)
}
}
}
} else { } else {
res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] }) res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
if (res.sourceDocuments) { if (res.sourceDocuments) {

View File

@ -1,4 +1,5 @@
import { flatten } from 'lodash' import { flatten } from 'lodash'
import { Tool } from '@langchain/core/tools'
import { BaseMessage } from '@langchain/core/messages' import { BaseMessage } from '@langchain/core/messages'
import { ChainValues } from '@langchain/core/utils/types' import { ChainValues } from '@langchain/core/utils/types'
import { RunnableSequence } from '@langchain/core/runnables' import { RunnableSequence } from '@langchain/core/runnables'
@ -125,6 +126,17 @@ class ToolAgent_Agents implements INode {
options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools) options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
usedTools = res.usedTools usedTools = res.usedTools
} }
// If the tool is set to returnDirect, stream the output to the client
if (res.usedTools && res.usedTools.length) {
let inputTools = nodeData.inputs?.tools
inputTools = flatten(inputTools)
for (const tool of res.usedTools) {
const inputTool = inputTools.find((inputTool: Tool) => inputTool.name === tool.tool)
if (inputTool && inputTool.returnDirect) {
options.socketIO.to(options.socketIOClientId).emit('token', tool.toolOutput)
}
}
}
} else { } else {
res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] }) res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
if (res.sourceDocuments) { if (res.sourceDocuments) {

View File

@ -142,6 +142,17 @@ class XMLAgent_Agents implements INode {
options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools) options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
usedTools = res.usedTools usedTools = res.usedTools
} }
// If the tool is set to returnDirect, stream the output to the client
if (res.usedTools && res.usedTools.length) {
let inputTools = nodeData.inputs?.tools
inputTools = flatten(inputTools)
for (const tool of res.usedTools) {
const inputTool = inputTools.find((inputTool: Tool) => inputTool.name === tool.tool)
if (inputTool && inputTool.returnDirect) {
options.socketIO.to(options.socketIOClientId).emit('token', tool.toolOutput)
}
}
}
} else { } else {
res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] }) res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
if (res.sourceDocuments) { if (res.sourceDocuments) {

View File

@ -18,7 +18,7 @@ class CustomTool_Tools implements INode {
constructor() { constructor() {
this.label = 'Custom Tool' this.label = 'Custom Tool'
this.name = 'customTool' this.name = 'customTool'
this.version = 1.0 this.version = 2.0
this.type = 'CustomTool' this.type = 'CustomTool'
this.icon = 'customtool.svg' this.icon = 'customtool.svg'
this.category = 'Tools' this.category = 'Tools'
@ -29,6 +29,13 @@ class CustomTool_Tools implements INode {
name: 'selectedTool', name: 'selectedTool',
type: 'asyncOptions', type: 'asyncOptions',
loadMethod: 'listTools' loadMethod: 'listTools'
},
{
label: 'Return Direct',
name: 'returnDirect',
description: 'Return the output of the tool directly to the user',
type: 'boolean',
optional: true
} }
] ]
this.baseClasses = [this.type, 'Tool', ...getBaseClasses(DynamicStructuredTool)] this.baseClasses = [this.type, 'Tool', ...getBaseClasses(DynamicStructuredTool)]
@ -66,6 +73,7 @@ class CustomTool_Tools implements INode {
const customToolName = nodeData.inputs?.customToolName as string const customToolName = nodeData.inputs?.customToolName as string
const customToolDesc = nodeData.inputs?.customToolDesc as string const customToolDesc = nodeData.inputs?.customToolDesc as string
const customToolSchema = nodeData.inputs?.customToolSchema as string const customToolSchema = nodeData.inputs?.customToolSchema as string
const customToolReturnDirect = nodeData.inputs?.returnDirect as boolean
const appDataSource = options.appDataSource as DataSource const appDataSource = options.appDataSource as DataSource
const databaseEntities = options.databaseEntities as IDatabaseEntity const databaseEntities = options.databaseEntities as IDatabaseEntity
@ -97,6 +105,7 @@ class CustomTool_Tools implements INode {
let dynamicStructuredTool = new DynamicStructuredTool(obj) let dynamicStructuredTool = new DynamicStructuredTool(obj)
dynamicStructuredTool.setVariables(variables) dynamicStructuredTool.setVariables(variables)
dynamicStructuredTool.setFlowObject(flow) dynamicStructuredTool.setFlowObject(flow)
dynamicStructuredTool.returnDirect = customToolReturnDirect
return dynamicStructuredTool return dynamicStructuredTool
} catch (e) { } catch (e) {