Bugfix/Move summarization as optional for multi agents (#2858)
add summarization as optional for multi agents
This commit is contained in:
parent c31a4c95e7
commit 368c69cbc5
@@ -49,7 +49,7 @@ class Supervisor_MultiAgents implements INode {
     constructor() {
         this.label = 'Supervisor'
         this.name = 'supervisor'
-        this.version = 2.0
+        this.version = 3.0
         this.type = 'Supervisor'
         this.icon = 'supervisor.svg'
         this.category = 'Multi Agents'
@@ -84,6 +84,14 @@ class Supervisor_MultiAgents implements INode {
                 description: 'Save the state of the agent',
                 optional: true
             },
+            {
+                label: 'Summarization',
+                name: 'summarization',
+                type: 'boolean',
+                description: 'Return final output as a summarization of the conversation',
+                optional: true,
+                additionalParams: true
+            },
             {
                 label: 'Recursion Limit',
                 name: 'recursionLimit',
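The new entry follows the same optional-parameter pattern as the node's existing inputs. A minimal sketch of the shape involved (illustrative only, not Flowise's actual INodeParams type; the field meanings are assumptions inferred from how the other inputs behave):

// Illustrative only: a reduced stand-in for the parameter object added above.
interface NodeParamSketch {
    label: string
    name: string
    type: string
    description?: string
    optional?: boolean // user may leave it unset
    additionalParams?: boolean // assumed: shown in the "Additional Parameters" dialog
}

const summarizationParam: NodeParamSketch = {
    label: 'Summarization',
    name: 'summarization',
    type: 'boolean',
    description: 'Return final output as a summarization of the conversation',
    optional: true,
    additionalParams: true
}

Because the parameter is both optional and tucked behind additional parameters, it defaults to off, which is presumably what makes summarization opt-in rather than always-on.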
@@ -110,6 +118,7 @@ class Supervisor_MultiAgents implements INode {
         const _recursionLimit = nodeData.inputs?.recursionLimit as string
         const recursionLimit = _recursionLimit ? parseFloat(_recursionLimit) : 100
         const moderations = (nodeData.inputs?.inputModeration as Moderation[]) ?? []
+        const summarization = nodeData.inputs?.summarization as string
 
         const abortControllerSignal = options.signal as AbortController
 
@@ -128,6 +137,257 @@ class Supervisor_MultiAgents implements INode {
 
             systemPrompt = systemPrompt.replaceAll('{team_members}', members.join(', '))
 
+            let userPrompt = `Given the conversation above, who should act next? Or should we FINISH? Select one of: ${memberOptions.join(
+                ', '
+            )}`
+
+            const tool = new RouteTool({
+                schema: z.object({
+                    reasoning: z.string(),
+                    next: z.enum(['FINISH', ...members]),
+                    instructions: z.string().describe('The specific instructions of the sub-task the next role should accomplish.')
+                })
+            })
+
+            let supervisor
+
+            if (llm instanceof ChatMistralAI) {
+                let prompt = ChatPromptTemplate.fromMessages([
+                    ['system', systemPrompt],
+                    new MessagesPlaceholder('messages'),
+                    ['human', userPrompt]
+                ])
+
+                const messages = await processImageMessage(1, llm, prompt, nodeData, options)
+                prompt = messages.prompt
+                multiModalMessageContent = messages.multiModalMessageContent
+
+                // Force Mistral to use tool
+                // @ts-ignore
+                const modelWithTool = llm.bind({
+                    tools: [tool],
+                    tool_choice: 'any',
+                    signal: abortControllerSignal ? abortControllerSignal.signal : undefined
+                })
+
+                const outputParser = new JsonOutputToolsParser()
+
+                supervisor = prompt
+                    .pipe(modelWithTool)
+                    .pipe(outputParser)
+                    .pipe((x) => {
+                        if (Array.isArray(x) && x.length) {
+                            const toolAgentAction = x[0]
+                            return {
+                                next: Object.keys(toolAgentAction.args).length ? toolAgentAction.args.next : 'FINISH',
+                                instructions: Object.keys(toolAgentAction.args).length
+                                    ? toolAgentAction.args.instructions
+                                    : 'Conversation finished',
+                                team_members: members.join(', ')
+                            }
+                        } else {
+                            return {
+                                next: 'FINISH',
+                                instructions: 'Conversation finished',
+                                team_members: members.join(', ')
+                            }
+                        }
+                    })
+            } else if (llm instanceof ChatAnthropic) {
+                // Force Anthropic to use tool : https://docs.anthropic.com/claude/docs/tool-use#forcing-tool-use
+                userPrompt = `Given the conversation above, who should act next? Or should we FINISH? Select one of: ${memberOptions.join(
+                    ', '
+                )}. Use the ${routerToolName} tool in your response.`
+
+                let prompt = ChatPromptTemplate.fromMessages([
+                    ['system', systemPrompt],
+                    new MessagesPlaceholder('messages'),
+                    ['human', userPrompt]
+                ])
+
+                const messages = await processImageMessage(1, llm, prompt, nodeData, options)
+                prompt = messages.prompt
+                multiModalMessageContent = messages.multiModalMessageContent
+
+                if (llm.bindTools === undefined) {
+                    throw new Error(`This agent only compatible with function calling models.`)
+                }
+
+                const modelWithTool = llm.bindTools([tool])
+
+                const outputParser = new ToolCallingAgentOutputParser()
+
+                supervisor = prompt
+                    .pipe(modelWithTool)
+                    .pipe(outputParser)
+                    .pipe((x) => {
+                        if (Array.isArray(x) && x.length) {
+                            const toolAgentAction = x[0] as any
+                            return {
+                                next: toolAgentAction.toolInput.next,
+                                instructions: toolAgentAction.toolInput.instructions,
+                                team_members: members.join(', ')
+                            }
+                        } else if (typeof x === 'object' && 'returnValues' in x) {
+                            return {
+                                next: 'FINISH',
+                                instructions: x.returnValues?.output,
+                                team_members: members.join(', ')
+                            }
+                        } else {
+                            return {
+                                next: 'FINISH',
+                                instructions: 'Conversation finished',
+                                team_members: members.join(', ')
+                            }
+                        }
+                    })
+            } else if (llm instanceof ChatOpenAI) {
+                let prompt = ChatPromptTemplate.fromMessages([
+                    ['system', systemPrompt],
+                    new MessagesPlaceholder('messages'),
+                    ['human', userPrompt]
+                ])
+
+                // @ts-ignore
+                const messages = await processImageMessage(1, llm, prompt, nodeData, options)
+                prompt = messages.prompt
+                multiModalMessageContent = messages.multiModalMessageContent
+
+                // Force OpenAI to use tool
+                const modelWithTool = llm.bind({
+                    tools: [tool],
+                    tool_choice: { type: 'function', function: { name: routerToolName } },
+                    signal: abortControllerSignal ? abortControllerSignal.signal : undefined
+                })
+
+                const outputParser = new ToolCallingAgentOutputParser()
+
+                supervisor = prompt
+                    .pipe(modelWithTool)
+                    .pipe(outputParser)
+                    .pipe((x) => {
+                        if (Array.isArray(x) && x.length) {
+                            const toolAgentAction = x[0] as any
+                            return {
+                                next: toolAgentAction.toolInput.next,
+                                instructions: toolAgentAction.toolInput.instructions,
+                                team_members: members.join(', ')
+                            }
+                        } else if (typeof x === 'object' && 'returnValues' in x) {
+                            return {
+                                next: 'FINISH',
+                                instructions: x.returnValues?.output,
+                                team_members: members.join(', ')
+                            }
+                        } else {
+                            return {
+                                next: 'FINISH',
+                                instructions: 'Conversation finished',
+                                team_members: members.join(', ')
+                            }
+                        }
+                    })
+            } else if (llm instanceof ChatGoogleGenerativeAI) {
+                /*
+                 * Gemini doesn't have system message and messages have to be alternate between model and user
+                 * So we have to place the system + human prompt at last
+                 */
+                let prompt = ChatPromptTemplate.fromMessages([
+                    ['system', systemPrompt],
+                    new MessagesPlaceholder('messages'),
+                    ['human', userPrompt]
+                ])
+
+                const messages = await processImageMessage(2, llm, prompt, nodeData, options)
+                prompt = messages.prompt
+                multiModalMessageContent = messages.multiModalMessageContent
+
+                if (llm.bindTools === undefined) {
+                    throw new Error(`This agent only compatible with function calling models.`)
+                }
+                const modelWithTool = llm.bindTools([tool])
+
+                const outputParser = new ToolCallingAgentOutputParser()
+
+                supervisor = prompt
+                    .pipe(modelWithTool)
+                    .pipe(outputParser)
+                    .pipe((x) => {
+                        if (Array.isArray(x) && x.length) {
+                            const toolAgentAction = x[0] as any
+                            return {
+                                next: toolAgentAction.toolInput.next,
+                                instructions: toolAgentAction.toolInput.instructions,
+                                team_members: members.join(', ')
+                            }
+                        } else if (typeof x === 'object' && 'returnValues' in x) {
+                            return {
+                                next: 'FINISH',
+                                instructions: x.returnValues?.output,
+                                team_members: members.join(', ')
+                            }
+                        } else {
+                            return {
+                                next: 'FINISH',
+                                instructions: 'Conversation finished',
+                                team_members: members.join(', ')
+                            }
+                        }
+                    })
+            } else {
+                let prompt = ChatPromptTemplate.fromMessages([
+                    ['system', systemPrompt],
+                    new MessagesPlaceholder('messages'),
+                    ['human', userPrompt]
+                ])
+
+                const messages = await processImageMessage(1, llm, prompt, nodeData, options)
+                prompt = messages.prompt
+                multiModalMessageContent = messages.multiModalMessageContent
+
+                if (llm.bindTools === undefined) {
+                    throw new Error(`This agent only compatible with function calling models.`)
+                }
+                const modelWithTool = llm.bindTools([tool])
+
+                const outputParser = new ToolCallingAgentOutputParser()
+
+                supervisor = prompt
+                    .pipe(modelWithTool)
+                    .pipe(outputParser)
+                    .pipe((x) => {
+                        if (Array.isArray(x) && x.length) {
+                            const toolAgentAction = x[0] as any
+                            return {
+                                next: toolAgentAction.toolInput.next,
+                                instructions: toolAgentAction.toolInput.instructions,
+                                team_members: members.join(', ')
+                            }
+                        } else if (typeof x === 'object' && 'returnValues' in x) {
+                            return {
+                                next: 'FINISH',
+                                instructions: x.returnValues?.output,
+                                team_members: members.join(', ')
+                            }
+                        } else {
+                            return {
+                                next: 'FINISH',
+                                instructions: 'Conversation finished',
+                                team_members: members.join(', ')
+                            }
+                        }
+                    })
+            }
+
+            return supervisor
+        }
+
+        async function createTeamSupervisorWithSummarize(llm: BaseChatModel, systemPrompt: string, members: string[]): Promise<Runnable> {
+            const memberOptions = ['FINISH', ...members]
+
+            systemPrompt = systemPrompt.replaceAll('{team_members}', members.join(', '))
+
             let userPrompt = `Given the conversation above, who should act next? Or should we FINISH? Select one of: ${memberOptions.join(
                 ', '
             )}
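Every provider branch above funnels the model through the same routing contract: a forced tool call whose arguments pick the next actor. A runnable sketch of that schema with zod (the member names and the `as [string, ...string[]]` cast are illustrative additions; in the node, `members` comes from the connected worker nodes):

import { z } from 'zod'

// Placeholder team; in the node above, `members` is built from the workers.
const members = ['researcher', 'writer']

// Same shape as the RouteTool schema in the hunk above. The cast reflects that
// `members` is only known at runtime, while z.enum's types expect a tuple.
const routeSchema = z.object({
    reasoning: z.string(),
    next: z.enum(['FINISH', ...members] as [string, ...string[]]),
    instructions: z.string().describe('The specific instructions of the sub-task the next role should accomplish.')
})

// A parsed routing decision might look like:
const decision = routeSchema.parse({
    reasoning: 'Research is done; the draft still needs writing.',
    next: 'writer',
    instructions: 'Write a summary based on the gathered notes.'
})
console.log(decision.next) // 'writer'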
@@ -247,7 +507,8 @@ class Supervisor_MultiAgents implements INode {
                 ['human', userPrompt]
             ])
 
-            const messages = await processImageMessage(1, llm as any, prompt, nodeData, options)
+            // @ts-ignore
+            const messages = await processImageMessage(1, llm, prompt, nodeData, options)
             prompt = messages.prompt
             multiModalMessageContent = messages.multiModalMessageContent
 
@@ -389,7 +650,9 @@ class Supervisor_MultiAgents implements INode {
             return supervisor
         }
 
-        const supervisorAgent = await createTeamSupervisor(llm, supervisorPrompt ? supervisorPrompt : sysPrompt, workersNodeNames)
+        const supervisorAgent = summarization
+            ? await createTeamSupervisorWithSummarize(llm, supervisorPrompt ? supervisorPrompt : sysPrompt, workersNodeNames)
+            : await createTeamSupervisor(llm, supervisorPrompt ? supervisorPrompt : sysPrompt, workersNodeNames)
 
         const supervisorNode = async (state: ITeamState, config: RunnableConfig) =>
             await agentNode(
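With both builders defined, the flag only selects which one runs; routing behavior is otherwise identical. A condensed sketch of the selection, using stand-in types and assuming the input arrives as a string, per the `as string` cast earlier in this diff:

// Stand-ins for the two builders added above; bodies omitted.
type SupervisorBuilder = (llm: unknown, systemPrompt: string, members: string[]) => Promise<unknown>

async function selectSupervisor(
    summarization: string | undefined, // assumed: boolean input serialized as a string
    createTeamSupervisor: SupervisorBuilder,
    createTeamSupervisorWithSummarize: SupervisorBuilder,
    llm: unknown,
    systemPrompt: string,
    members: string[]
) {
    // A falsy value (checkbox off) selects the plain supervisor, so existing
    // flows keep their previous behavior by default.
    return summarization
        ? createTeamSupervisorWithSummarize(llm, systemPrompt, members)
        : createTeamSupervisor(llm, systemPrompt, members)
}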
@@ -433,7 +696,7 @@ async function agentNode(
             throw new Error('Aborted!')
         }
         const result = await agent.invoke({ ...state, signal: abortControllerSignal.signal }, config)
-        const additional_kwargs: ICommonObject = { nodeId }
+        const additional_kwargs: ICommonObject = { nodeId, type: 'supervisor' }
         result.additional_kwargs = { ...result.additional_kwargs, ...additional_kwargs }
         return result
     } catch (error) {
@@ -283,7 +283,7 @@ async function agentNode(
         }
 
         const result = await agent.invoke({ ...state, signal: abortControllerSignal.signal }, config)
-        const additional_kwargs: ICommonObject = { nodeId }
+        const additional_kwargs: ICommonObject = { nodeId, type: 'worker' }
         if (result.usedTools) {
            additional_kwargs.usedTools = result.usedTools
        }
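Both agentNode variants (supervisor and worker) now stamp their output, which is what the lastWorkerResult bookkeeping further below relies on. A sketch with a simplified message shape (the real type is LangChain's message class; `tagMessage` is a hypothetical helper, not code from this commit):

// Simplified stand-in for the LangChain message shape used above.
interface TaggedMessage {
    content: string
    additional_kwargs: { nodeId?: string; type?: 'supervisor' | 'worker'; [key: string]: unknown }
}

// Mirrors the merge in the hunks above: preserve existing kwargs, add the tag.
function tagMessage(msg: TaggedMessage, nodeId: string, type: 'supervisor' | 'worker'): TaggedMessage {
    msg.additional_kwargs = { ...msg.additional_kwargs, nodeId, type }
    return msg
}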
@@ -206,7 +206,7 @@ export interface ITeamState {
    team_members: string[]
    next: string
    instructions: string
-    summarization: string
+    summarization?: string
}

export interface ISeqAgentsState {
@@ -147,6 +147,7 @@ export const buildAgentGraph = async (
    let streamResults
    let finalResult = ''
    let finalSummarization = ''
+    let lastWorkerResult = ''
    let agentReasoning: IAgentReasoning[] = []
    let isSequential = false
    let lastMessageRaw = {} as AIMessageChunk
@@ -182,7 +183,8 @@ export const buildAgentGraph = async (
                incomingInput.question,
                chatHistory,
                incomingInput?.overrideConfig,
-                sessionId || chatId
+                sessionId || chatId,
+                seqAgentNodes.some((node) => node.data.inputs?.summarization)
            )
        } else {
            isSequential = true
@@ -277,6 +279,12 @@ export const buildAgentGraph = async (
 
                    finalSummarization = output[agentName]?.summarization ?? ''
 
+                    lastWorkerResult =
+                        output[agentName]?.messages?.length &&
+                        output[agentName].messages[output[agentName].messages.length - 1]?.additional_kwargs?.type === 'worker'
+                            ? output[agentName].messages[output[agentName].messages.length - 1].content
+                            : lastWorkerResult
+
                    if (socketIO && incomingInput.socketIOClientId) {
                        if (!isStreamingStarted) {
                            isStreamingStarted = true
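Unrolled, the ternary above keeps the latest message only when a worker produced it, and otherwise preserves the previous value. A sketch under the same simplified message shape (`trackLastWorkerResult` is hypothetical):

interface TaggedMessage {
    content: string
    additional_kwargs?: { type?: string }
}

function trackLastWorkerResult(messages: TaggedMessage[] | undefined, previous: string): string {
    const last = messages && messages.length ? messages[messages.length - 1] : undefined
    // Only worker-tagged output replaces the running value; supervisor chatter does not.
    return last && last.additional_kwargs?.type === 'worker' ? last.content : previous
}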
@@ -305,10 +313,13 @@ export const buildAgentGraph = async (
 
        /*
         * For multi agents mode, sometimes finalResult is empty
-         * Provide summary as final result
+         * 1.) Provide lastWorkerResult as final result if available
+         * 2.) Provide summary as final result if available
         */
-        if (!isSequential && !finalResult && finalSummarization) {
-            finalResult = finalSummarization
+        if (!isSequential && !finalResult) {
+            if (lastWorkerResult) finalResult = lastWorkerResult
+            else if (finalSummarization) finalResult = finalSummarization
+
            if (socketIO && incomingInput.socketIOClientId) {
                socketIO.to(incomingInput.socketIOClientId).emit('token', finalResult)
            }
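The resulting fallback order for multi-agent runs, written out as a small function (`resolveFinalResult` is hypothetical and omits the isSequential guard):

function resolveFinalResult(finalResult: string, lastWorkerResult: string, finalSummarization: string): string {
    if (finalResult) return finalResult // graph produced an explicit final result
    if (lastWorkerResult) return lastWorkerResult // 1.) last worker output
    return finalSummarization // 2.) summarization, if enabled and produced
}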
@@ -425,6 +436,7 @@ export const buildAgentGraph = async (
 * @param {string} question
 * @param {ICommonObject} overrideConfig
 * @param {string} threadId
+ * @param {boolean} summarization
 */
const compileMultiAgentsGraph = async (
    chatflow: IChatFlow,
@@ -437,7 +449,8 @@ const compileMultiAgentsGraph = async (
    question: string,
    chatHistory: IMessage[] = [],
    overrideConfig?: ICommonObject,
-    threadId?: string
+    threadId?: string,
+    summarization?: boolean
) => {
    const appServer = getRunningExpressApp()
    const channels: ITeamState = {
@@ -447,10 +460,11 @@ const compileMultiAgentsGraph = async (
        },
        next: 'initialState',
        instructions: "Solve the user's request.",
-        team_members: [],
-        summarization: 'summarize'
+        team_members: []
    }
 
+    if (summarization) channels.summarization = 'summarize'
+
    const workflowGraph = new StateGraph<ITeamState>({
        //@ts-ignore
        channels
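Because ITeamState.summarization is now optional, the channel is registered only on demand, so graphs built without the flag never carry the extra state key. A condensed sketch (`buildChannels` is hypothetical; reducer details and the other ITeamState fields are omitted):

function buildChannels(summarization?: boolean): Record<string, unknown> {
    const channels: Record<string, unknown> = {
        next: 'initialState',
        instructions: "Solve the user's request.",
        team_members: [] as string[]
    }
    // Present only when the Summarization checkbox is on, matching the
    // optional `summarization?: string` field on ITeamState.
    if (summarization) channels.summarization = 'summarize'
    return channels
}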