Bugfix/Move summarization as optional for multi agents (#2858)
add summarization as optional for multi agents
This commit is contained in:
parent
c31a4c95e7
commit
368c69cbc5
|
|
@ -49,7 +49,7 @@ class Supervisor_MultiAgents implements INode {
|
||||||
constructor() {
|
constructor() {
|
||||||
this.label = 'Supervisor'
|
this.label = 'Supervisor'
|
||||||
this.name = 'supervisor'
|
this.name = 'supervisor'
|
||||||
this.version = 2.0
|
this.version = 3.0
|
||||||
this.type = 'Supervisor'
|
this.type = 'Supervisor'
|
||||||
this.icon = 'supervisor.svg'
|
this.icon = 'supervisor.svg'
|
||||||
this.category = 'Multi Agents'
|
this.category = 'Multi Agents'
|
||||||
|
|
@ -84,6 +84,14 @@ class Supervisor_MultiAgents implements INode {
|
||||||
description: 'Save the state of the agent',
|
description: 'Save the state of the agent',
|
||||||
optional: true
|
optional: true
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
label: 'Summarization',
|
||||||
|
name: 'summarization',
|
||||||
|
type: 'boolean',
|
||||||
|
description: 'Return final output as a summarization of the conversation',
|
||||||
|
optional: true,
|
||||||
|
additionalParams: true
|
||||||
|
},
|
||||||
{
|
{
|
||||||
label: 'Recursion Limit',
|
label: 'Recursion Limit',
|
||||||
name: 'recursionLimit',
|
name: 'recursionLimit',
|
||||||
|
|
@ -110,6 +118,7 @@ class Supervisor_MultiAgents implements INode {
|
||||||
const _recursionLimit = nodeData.inputs?.recursionLimit as string
|
const _recursionLimit = nodeData.inputs?.recursionLimit as string
|
||||||
const recursionLimit = _recursionLimit ? parseFloat(_recursionLimit) : 100
|
const recursionLimit = _recursionLimit ? parseFloat(_recursionLimit) : 100
|
||||||
const moderations = (nodeData.inputs?.inputModeration as Moderation[]) ?? []
|
const moderations = (nodeData.inputs?.inputModeration as Moderation[]) ?? []
|
||||||
|
const summarization = nodeData.inputs?.summarization as string
|
||||||
|
|
||||||
const abortControllerSignal = options.signal as AbortController
|
const abortControllerSignal = options.signal as AbortController
|
||||||
|
|
||||||
|
|
@ -128,6 +137,257 @@ class Supervisor_MultiAgents implements INode {
|
||||||
|
|
||||||
systemPrompt = systemPrompt.replaceAll('{team_members}', members.join(', '))
|
systemPrompt = systemPrompt.replaceAll('{team_members}', members.join(', '))
|
||||||
|
|
||||||
|
let userPrompt = `Given the conversation above, who should act next? Or should we FINISH? Select one of: ${memberOptions.join(
|
||||||
|
', '
|
||||||
|
)}`
|
||||||
|
|
||||||
|
const tool = new RouteTool({
|
||||||
|
schema: z.object({
|
||||||
|
reasoning: z.string(),
|
||||||
|
next: z.enum(['FINISH', ...members]),
|
||||||
|
instructions: z.string().describe('The specific instructions of the sub-task the next role should accomplish.')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
let supervisor
|
||||||
|
|
||||||
|
if (llm instanceof ChatMistralAI) {
|
||||||
|
let prompt = ChatPromptTemplate.fromMessages([
|
||||||
|
['system', systemPrompt],
|
||||||
|
new MessagesPlaceholder('messages'),
|
||||||
|
['human', userPrompt]
|
||||||
|
])
|
||||||
|
|
||||||
|
const messages = await processImageMessage(1, llm, prompt, nodeData, options)
|
||||||
|
prompt = messages.prompt
|
||||||
|
multiModalMessageContent = messages.multiModalMessageContent
|
||||||
|
|
||||||
|
// Force Mistral to use tool
|
||||||
|
// @ts-ignore
|
||||||
|
const modelWithTool = llm.bind({
|
||||||
|
tools: [tool],
|
||||||
|
tool_choice: 'any',
|
||||||
|
signal: abortControllerSignal ? abortControllerSignal.signal : undefined
|
||||||
|
})
|
||||||
|
|
||||||
|
const outputParser = new JsonOutputToolsParser()
|
||||||
|
|
||||||
|
supervisor = prompt
|
||||||
|
.pipe(modelWithTool)
|
||||||
|
.pipe(outputParser)
|
||||||
|
.pipe((x) => {
|
||||||
|
if (Array.isArray(x) && x.length) {
|
||||||
|
const toolAgentAction = x[0]
|
||||||
|
return {
|
||||||
|
next: Object.keys(toolAgentAction.args).length ? toolAgentAction.args.next : 'FINISH',
|
||||||
|
instructions: Object.keys(toolAgentAction.args).length
|
||||||
|
? toolAgentAction.args.instructions
|
||||||
|
: 'Conversation finished',
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
next: 'FINISH',
|
||||||
|
instructions: 'Conversation finished',
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else if (llm instanceof ChatAnthropic) {
|
||||||
|
// Force Anthropic to use tool : https://docs.anthropic.com/claude/docs/tool-use#forcing-tool-use
|
||||||
|
userPrompt = `Given the conversation above, who should act next? Or should we FINISH? Select one of: ${memberOptions.join(
|
||||||
|
', '
|
||||||
|
)}. Use the ${routerToolName} tool in your response.`
|
||||||
|
|
||||||
|
let prompt = ChatPromptTemplate.fromMessages([
|
||||||
|
['system', systemPrompt],
|
||||||
|
new MessagesPlaceholder('messages'),
|
||||||
|
['human', userPrompt]
|
||||||
|
])
|
||||||
|
|
||||||
|
const messages = await processImageMessage(1, llm, prompt, nodeData, options)
|
||||||
|
prompt = messages.prompt
|
||||||
|
multiModalMessageContent = messages.multiModalMessageContent
|
||||||
|
|
||||||
|
if (llm.bindTools === undefined) {
|
||||||
|
throw new Error(`This agent only compatible with function calling models.`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const modelWithTool = llm.bindTools([tool])
|
||||||
|
|
||||||
|
const outputParser = new ToolCallingAgentOutputParser()
|
||||||
|
|
||||||
|
supervisor = prompt
|
||||||
|
.pipe(modelWithTool)
|
||||||
|
.pipe(outputParser)
|
||||||
|
.pipe((x) => {
|
||||||
|
if (Array.isArray(x) && x.length) {
|
||||||
|
const toolAgentAction = x[0] as any
|
||||||
|
return {
|
||||||
|
next: toolAgentAction.toolInput.next,
|
||||||
|
instructions: toolAgentAction.toolInput.instructions,
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
} else if (typeof x === 'object' && 'returnValues' in x) {
|
||||||
|
return {
|
||||||
|
next: 'FINISH',
|
||||||
|
instructions: x.returnValues?.output,
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
next: 'FINISH',
|
||||||
|
instructions: 'Conversation finished',
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else if (llm instanceof ChatOpenAI) {
|
||||||
|
let prompt = ChatPromptTemplate.fromMessages([
|
||||||
|
['system', systemPrompt],
|
||||||
|
new MessagesPlaceholder('messages'),
|
||||||
|
['human', userPrompt]
|
||||||
|
])
|
||||||
|
|
||||||
|
// @ts-ignore
|
||||||
|
const messages = await processImageMessage(1, llm, prompt, nodeData, options)
|
||||||
|
prompt = messages.prompt
|
||||||
|
multiModalMessageContent = messages.multiModalMessageContent
|
||||||
|
|
||||||
|
// Force OpenAI to use tool
|
||||||
|
const modelWithTool = llm.bind({
|
||||||
|
tools: [tool],
|
||||||
|
tool_choice: { type: 'function', function: { name: routerToolName } },
|
||||||
|
signal: abortControllerSignal ? abortControllerSignal.signal : undefined
|
||||||
|
})
|
||||||
|
|
||||||
|
const outputParser = new ToolCallingAgentOutputParser()
|
||||||
|
|
||||||
|
supervisor = prompt
|
||||||
|
.pipe(modelWithTool)
|
||||||
|
.pipe(outputParser)
|
||||||
|
.pipe((x) => {
|
||||||
|
if (Array.isArray(x) && x.length) {
|
||||||
|
const toolAgentAction = x[0] as any
|
||||||
|
return {
|
||||||
|
next: toolAgentAction.toolInput.next,
|
||||||
|
instructions: toolAgentAction.toolInput.instructions,
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
} else if (typeof x === 'object' && 'returnValues' in x) {
|
||||||
|
return {
|
||||||
|
next: 'FINISH',
|
||||||
|
instructions: x.returnValues?.output,
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
next: 'FINISH',
|
||||||
|
instructions: 'Conversation finished',
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else if (llm instanceof ChatGoogleGenerativeAI) {
|
||||||
|
/*
|
||||||
|
* Gemini doesn't have system message and messages have to be alternate between model and user
|
||||||
|
* So we have to place the system + human prompt at last
|
||||||
|
*/
|
||||||
|
let prompt = ChatPromptTemplate.fromMessages([
|
||||||
|
['system', systemPrompt],
|
||||||
|
new MessagesPlaceholder('messages'),
|
||||||
|
['human', userPrompt]
|
||||||
|
])
|
||||||
|
|
||||||
|
const messages = await processImageMessage(2, llm, prompt, nodeData, options)
|
||||||
|
prompt = messages.prompt
|
||||||
|
multiModalMessageContent = messages.multiModalMessageContent
|
||||||
|
|
||||||
|
if (llm.bindTools === undefined) {
|
||||||
|
throw new Error(`This agent only compatible with function calling models.`)
|
||||||
|
}
|
||||||
|
const modelWithTool = llm.bindTools([tool])
|
||||||
|
|
||||||
|
const outputParser = new ToolCallingAgentOutputParser()
|
||||||
|
|
||||||
|
supervisor = prompt
|
||||||
|
.pipe(modelWithTool)
|
||||||
|
.pipe(outputParser)
|
||||||
|
.pipe((x) => {
|
||||||
|
if (Array.isArray(x) && x.length) {
|
||||||
|
const toolAgentAction = x[0] as any
|
||||||
|
return {
|
||||||
|
next: toolAgentAction.toolInput.next,
|
||||||
|
instructions: toolAgentAction.toolInput.instructions,
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
} else if (typeof x === 'object' && 'returnValues' in x) {
|
||||||
|
return {
|
||||||
|
next: 'FINISH',
|
||||||
|
instructions: x.returnValues?.output,
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
next: 'FINISH',
|
||||||
|
instructions: 'Conversation finished',
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
let prompt = ChatPromptTemplate.fromMessages([
|
||||||
|
['system', systemPrompt],
|
||||||
|
new MessagesPlaceholder('messages'),
|
||||||
|
['human', userPrompt]
|
||||||
|
])
|
||||||
|
|
||||||
|
const messages = await processImageMessage(1, llm, prompt, nodeData, options)
|
||||||
|
prompt = messages.prompt
|
||||||
|
multiModalMessageContent = messages.multiModalMessageContent
|
||||||
|
|
||||||
|
if (llm.bindTools === undefined) {
|
||||||
|
throw new Error(`This agent only compatible with function calling models.`)
|
||||||
|
}
|
||||||
|
const modelWithTool = llm.bindTools([tool])
|
||||||
|
|
||||||
|
const outputParser = new ToolCallingAgentOutputParser()
|
||||||
|
|
||||||
|
supervisor = prompt
|
||||||
|
.pipe(modelWithTool)
|
||||||
|
.pipe(outputParser)
|
||||||
|
.pipe((x) => {
|
||||||
|
if (Array.isArray(x) && x.length) {
|
||||||
|
const toolAgentAction = x[0] as any
|
||||||
|
return {
|
||||||
|
next: toolAgentAction.toolInput.next,
|
||||||
|
instructions: toolAgentAction.toolInput.instructions,
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
} else if (typeof x === 'object' && 'returnValues' in x) {
|
||||||
|
return {
|
||||||
|
next: 'FINISH',
|
||||||
|
instructions: x.returnValues?.output,
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
next: 'FINISH',
|
||||||
|
instructions: 'Conversation finished',
|
||||||
|
team_members: members.join(', ')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return supervisor
|
||||||
|
}
|
||||||
|
|
||||||
|
async function createTeamSupervisorWithSummarize(llm: BaseChatModel, systemPrompt: string, members: string[]): Promise<Runnable> {
|
||||||
|
const memberOptions = ['FINISH', ...members]
|
||||||
|
|
||||||
|
systemPrompt = systemPrompt.replaceAll('{team_members}', members.join(', '))
|
||||||
|
|
||||||
let userPrompt = `Given the conversation above, who should act next? Or should we FINISH? Select one of: ${memberOptions.join(
|
let userPrompt = `Given the conversation above, who should act next? Or should we FINISH? Select one of: ${memberOptions.join(
|
||||||
', '
|
', '
|
||||||
)}
|
)}
|
||||||
|
|
@ -247,7 +507,8 @@ class Supervisor_MultiAgents implements INode {
|
||||||
['human', userPrompt]
|
['human', userPrompt]
|
||||||
])
|
])
|
||||||
|
|
||||||
const messages = await processImageMessage(1, llm as any, prompt, nodeData, options)
|
// @ts-ignore
|
||||||
|
const messages = await processImageMessage(1, llm, prompt, nodeData, options)
|
||||||
prompt = messages.prompt
|
prompt = messages.prompt
|
||||||
multiModalMessageContent = messages.multiModalMessageContent
|
multiModalMessageContent = messages.multiModalMessageContent
|
||||||
|
|
||||||
|
|
@ -389,7 +650,9 @@ class Supervisor_MultiAgents implements INode {
|
||||||
return supervisor
|
return supervisor
|
||||||
}
|
}
|
||||||
|
|
||||||
const supervisorAgent = await createTeamSupervisor(llm, supervisorPrompt ? supervisorPrompt : sysPrompt, workersNodeNames)
|
const supervisorAgent = summarization
|
||||||
|
? await createTeamSupervisorWithSummarize(llm, supervisorPrompt ? supervisorPrompt : sysPrompt, workersNodeNames)
|
||||||
|
: await createTeamSupervisor(llm, supervisorPrompt ? supervisorPrompt : sysPrompt, workersNodeNames)
|
||||||
|
|
||||||
const supervisorNode = async (state: ITeamState, config: RunnableConfig) =>
|
const supervisorNode = async (state: ITeamState, config: RunnableConfig) =>
|
||||||
await agentNode(
|
await agentNode(
|
||||||
|
|
@ -433,7 +696,7 @@ async function agentNode(
|
||||||
throw new Error('Aborted!')
|
throw new Error('Aborted!')
|
||||||
}
|
}
|
||||||
const result = await agent.invoke({ ...state, signal: abortControllerSignal.signal }, config)
|
const result = await agent.invoke({ ...state, signal: abortControllerSignal.signal }, config)
|
||||||
const additional_kwargs: ICommonObject = { nodeId }
|
const additional_kwargs: ICommonObject = { nodeId, type: 'supervisor' }
|
||||||
result.additional_kwargs = { ...result.additional_kwargs, ...additional_kwargs }
|
result.additional_kwargs = { ...result.additional_kwargs, ...additional_kwargs }
|
||||||
return result
|
return result
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|
|
||||||
|
|
@ -283,7 +283,7 @@ async function agentNode(
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = await agent.invoke({ ...state, signal: abortControllerSignal.signal }, config)
|
const result = await agent.invoke({ ...state, signal: abortControllerSignal.signal }, config)
|
||||||
const additional_kwargs: ICommonObject = { nodeId }
|
const additional_kwargs: ICommonObject = { nodeId, type: 'worker' }
|
||||||
if (result.usedTools) {
|
if (result.usedTools) {
|
||||||
additional_kwargs.usedTools = result.usedTools
|
additional_kwargs.usedTools = result.usedTools
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -206,7 +206,7 @@ export interface ITeamState {
|
||||||
team_members: string[]
|
team_members: string[]
|
||||||
next: string
|
next: string
|
||||||
instructions: string
|
instructions: string
|
||||||
summarization: string
|
summarization?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ISeqAgentsState {
|
export interface ISeqAgentsState {
|
||||||
|
|
|
||||||
|
|
@ -147,6 +147,7 @@ export const buildAgentGraph = async (
|
||||||
let streamResults
|
let streamResults
|
||||||
let finalResult = ''
|
let finalResult = ''
|
||||||
let finalSummarization = ''
|
let finalSummarization = ''
|
||||||
|
let lastWorkerResult = ''
|
||||||
let agentReasoning: IAgentReasoning[] = []
|
let agentReasoning: IAgentReasoning[] = []
|
||||||
let isSequential = false
|
let isSequential = false
|
||||||
let lastMessageRaw = {} as AIMessageChunk
|
let lastMessageRaw = {} as AIMessageChunk
|
||||||
|
|
@ -182,7 +183,8 @@ export const buildAgentGraph = async (
|
||||||
incomingInput.question,
|
incomingInput.question,
|
||||||
chatHistory,
|
chatHistory,
|
||||||
incomingInput?.overrideConfig,
|
incomingInput?.overrideConfig,
|
||||||
sessionId || chatId
|
sessionId || chatId,
|
||||||
|
seqAgentNodes.some((node) => node.data.inputs?.summarization)
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
isSequential = true
|
isSequential = true
|
||||||
|
|
@ -277,6 +279,12 @@ export const buildAgentGraph = async (
|
||||||
|
|
||||||
finalSummarization = output[agentName]?.summarization ?? ''
|
finalSummarization = output[agentName]?.summarization ?? ''
|
||||||
|
|
||||||
|
lastWorkerResult =
|
||||||
|
output[agentName]?.messages?.length &&
|
||||||
|
output[agentName].messages[output[agentName].messages.length - 1]?.additional_kwargs?.type === 'worker'
|
||||||
|
? output[agentName].messages[output[agentName].messages.length - 1].content
|
||||||
|
: lastWorkerResult
|
||||||
|
|
||||||
if (socketIO && incomingInput.socketIOClientId) {
|
if (socketIO && incomingInput.socketIOClientId) {
|
||||||
if (!isStreamingStarted) {
|
if (!isStreamingStarted) {
|
||||||
isStreamingStarted = true
|
isStreamingStarted = true
|
||||||
|
|
@ -305,10 +313,13 @@ export const buildAgentGraph = async (
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* For multi agents mode, sometimes finalResult is empty
|
* For multi agents mode, sometimes finalResult is empty
|
||||||
* Provide summary as final result
|
* 1.) Provide lastWorkerResult as final result if available
|
||||||
|
* 2.) Provide summary as final result if available
|
||||||
*/
|
*/
|
||||||
if (!isSequential && !finalResult && finalSummarization) {
|
if (!isSequential && !finalResult) {
|
||||||
finalResult = finalSummarization
|
if (lastWorkerResult) finalResult = lastWorkerResult
|
||||||
|
else if (finalSummarization) finalResult = finalSummarization
|
||||||
|
|
||||||
if (socketIO && incomingInput.socketIOClientId) {
|
if (socketIO && incomingInput.socketIOClientId) {
|
||||||
socketIO.to(incomingInput.socketIOClientId).emit('token', finalResult)
|
socketIO.to(incomingInput.socketIOClientId).emit('token', finalResult)
|
||||||
}
|
}
|
||||||
|
|
@ -425,6 +436,7 @@ export const buildAgentGraph = async (
|
||||||
* @param {string} question
|
* @param {string} question
|
||||||
* @param {ICommonObject} overrideConfig
|
* @param {ICommonObject} overrideConfig
|
||||||
* @param {string} threadId
|
* @param {string} threadId
|
||||||
|
* @param {boolean} summarization
|
||||||
*/
|
*/
|
||||||
const compileMultiAgentsGraph = async (
|
const compileMultiAgentsGraph = async (
|
||||||
chatflow: IChatFlow,
|
chatflow: IChatFlow,
|
||||||
|
|
@ -437,7 +449,8 @@ const compileMultiAgentsGraph = async (
|
||||||
question: string,
|
question: string,
|
||||||
chatHistory: IMessage[] = [],
|
chatHistory: IMessage[] = [],
|
||||||
overrideConfig?: ICommonObject,
|
overrideConfig?: ICommonObject,
|
||||||
threadId?: string
|
threadId?: string,
|
||||||
|
summarization?: boolean
|
||||||
) => {
|
) => {
|
||||||
const appServer = getRunningExpressApp()
|
const appServer = getRunningExpressApp()
|
||||||
const channels: ITeamState = {
|
const channels: ITeamState = {
|
||||||
|
|
@ -447,10 +460,11 @@ const compileMultiAgentsGraph = async (
|
||||||
},
|
},
|
||||||
next: 'initialState',
|
next: 'initialState',
|
||||||
instructions: "Solve the user's request.",
|
instructions: "Solve the user's request.",
|
||||||
team_members: [],
|
team_members: []
|
||||||
summarization: 'summarize'
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (summarization) channels.summarization = 'summarize'
|
||||||
|
|
||||||
const workflowGraph = new StateGraph<ITeamState>({
|
const workflowGraph = new StateGraph<ITeamState>({
|
||||||
//@ts-ignore
|
//@ts-ignore
|
||||||
channels
|
channels
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue