Bugfix/Move summarization as optional for multi agents (#2858)

Add summarization as an optional setting for multi-agent flows
This commit is contained in:
Henry Heng 2024-07-23 15:15:41 +01:00 committed by GitHub
parent c31a4c95e7
commit 368c69cbc5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 290 additions and 13 deletions

View File

@ -49,7 +49,7 @@ class Supervisor_MultiAgents implements INode {
constructor() {
this.label = 'Supervisor'
this.name = 'supervisor'
this.version = 2.0
this.version = 3.0
this.type = 'Supervisor'
this.icon = 'supervisor.svg'
this.category = 'Multi Agents'
@ -84,6 +84,14 @@ class Supervisor_MultiAgents implements INode {
description: 'Save the state of the agent',
optional: true
},
{
label: 'Summarization',
name: 'summarization',
type: 'boolean',
description: 'Return final output as a summarization of the conversation',
optional: true,
additionalParams: true
},
{
label: 'Recursion Limit',
name: 'recursionLimit',
@ -110,6 +118,7 @@ class Supervisor_MultiAgents implements INode {
const _recursionLimit = nodeData.inputs?.recursionLimit as string
const recursionLimit = _recursionLimit ? parseFloat(_recursionLimit) : 100
const moderations = (nodeData.inputs?.inputModeration as Moderation[]) ?? []
const summarization = nodeData.inputs?.summarization as string
const abortControllerSignal = options.signal as AbortController
@ -128,6 +137,257 @@ class Supervisor_MultiAgents implements INode {
systemPrompt = systemPrompt.replaceAll('{team_members}', members.join(', '))
let userPrompt = `Given the conversation above, who should act next? Or should we FINISH? Select one of: ${memberOptions.join(
', '
)}`
const tool = new RouteTool({
schema: z.object({
reasoning: z.string(),
next: z.enum(['FINISH', ...members]),
instructions: z.string().describe('The specific instructions of the sub-task the next role should accomplish.')
})
})
let supervisor
if (llm instanceof ChatMistralAI) {
let prompt = ChatPromptTemplate.fromMessages([
['system', systemPrompt],
new MessagesPlaceholder('messages'),
['human', userPrompt]
])
const messages = await processImageMessage(1, llm, prompt, nodeData, options)
prompt = messages.prompt
multiModalMessageContent = messages.multiModalMessageContent
// Force Mistral to use tool
// @ts-ignore
const modelWithTool = llm.bind({
tools: [tool],
tool_choice: 'any',
signal: abortControllerSignal ? abortControllerSignal.signal : undefined
})
const outputParser = new JsonOutputToolsParser()
supervisor = prompt
.pipe(modelWithTool)
.pipe(outputParser)
.pipe((x) => {
if (Array.isArray(x) && x.length) {
const toolAgentAction = x[0]
return {
next: Object.keys(toolAgentAction.args).length ? toolAgentAction.args.next : 'FINISH',
instructions: Object.keys(toolAgentAction.args).length
? toolAgentAction.args.instructions
: 'Conversation finished',
team_members: members.join(', ')
}
} else {
return {
next: 'FINISH',
instructions: 'Conversation finished',
team_members: members.join(', ')
}
}
})
} else if (llm instanceof ChatAnthropic) {
// Force Anthropic to use tool : https://docs.anthropic.com/claude/docs/tool-use#forcing-tool-use
userPrompt = `Given the conversation above, who should act next? Or should we FINISH? Select one of: ${memberOptions.join(
', '
)}. Use the ${routerToolName} tool in your response.`
let prompt = ChatPromptTemplate.fromMessages([
['system', systemPrompt],
new MessagesPlaceholder('messages'),
['human', userPrompt]
])
const messages = await processImageMessage(1, llm, prompt, nodeData, options)
prompt = messages.prompt
multiModalMessageContent = messages.multiModalMessageContent
if (llm.bindTools === undefined) {
throw new Error(`This agent only compatible with function calling models.`)
}
const modelWithTool = llm.bindTools([tool])
const outputParser = new ToolCallingAgentOutputParser()
supervisor = prompt
.pipe(modelWithTool)
.pipe(outputParser)
.pipe((x) => {
if (Array.isArray(x) && x.length) {
const toolAgentAction = x[0] as any
return {
next: toolAgentAction.toolInput.next,
instructions: toolAgentAction.toolInput.instructions,
team_members: members.join(', ')
}
} else if (typeof x === 'object' && 'returnValues' in x) {
return {
next: 'FINISH',
instructions: x.returnValues?.output,
team_members: members.join(', ')
}
} else {
return {
next: 'FINISH',
instructions: 'Conversation finished',
team_members: members.join(', ')
}
}
})
} else if (llm instanceof ChatOpenAI) {
let prompt = ChatPromptTemplate.fromMessages([
['system', systemPrompt],
new MessagesPlaceholder('messages'),
['human', userPrompt]
])
// @ts-ignore
const messages = await processImageMessage(1, llm, prompt, nodeData, options)
prompt = messages.prompt
multiModalMessageContent = messages.multiModalMessageContent
// Force OpenAI to use tool
const modelWithTool = llm.bind({
tools: [tool],
tool_choice: { type: 'function', function: { name: routerToolName } },
signal: abortControllerSignal ? abortControllerSignal.signal : undefined
})
const outputParser = new ToolCallingAgentOutputParser()
supervisor = prompt
.pipe(modelWithTool)
.pipe(outputParser)
.pipe((x) => {
if (Array.isArray(x) && x.length) {
const toolAgentAction = x[0] as any
return {
next: toolAgentAction.toolInput.next,
instructions: toolAgentAction.toolInput.instructions,
team_members: members.join(', ')
}
} else if (typeof x === 'object' && 'returnValues' in x) {
return {
next: 'FINISH',
instructions: x.returnValues?.output,
team_members: members.join(', ')
}
} else {
return {
next: 'FINISH',
instructions: 'Conversation finished',
team_members: members.join(', ')
}
}
})
} else if (llm instanceof ChatGoogleGenerativeAI) {
/*
* Gemini doesn't have system message and messages have to be alternate between model and user
* So we have to place the system + human prompt at last
*/
let prompt = ChatPromptTemplate.fromMessages([
['system', systemPrompt],
new MessagesPlaceholder('messages'),
['human', userPrompt]
])
const messages = await processImageMessage(2, llm, prompt, nodeData, options)
prompt = messages.prompt
multiModalMessageContent = messages.multiModalMessageContent
if (llm.bindTools === undefined) {
throw new Error(`This agent only compatible with function calling models.`)
}
const modelWithTool = llm.bindTools([tool])
const outputParser = new ToolCallingAgentOutputParser()
supervisor = prompt
.pipe(modelWithTool)
.pipe(outputParser)
.pipe((x) => {
if (Array.isArray(x) && x.length) {
const toolAgentAction = x[0] as any
return {
next: toolAgentAction.toolInput.next,
instructions: toolAgentAction.toolInput.instructions,
team_members: members.join(', ')
}
} else if (typeof x === 'object' && 'returnValues' in x) {
return {
next: 'FINISH',
instructions: x.returnValues?.output,
team_members: members.join(', ')
}
} else {
return {
next: 'FINISH',
instructions: 'Conversation finished',
team_members: members.join(', ')
}
}
})
} else {
let prompt = ChatPromptTemplate.fromMessages([
['system', systemPrompt],
new MessagesPlaceholder('messages'),
['human', userPrompt]
])
const messages = await processImageMessage(1, llm, prompt, nodeData, options)
prompt = messages.prompt
multiModalMessageContent = messages.multiModalMessageContent
if (llm.bindTools === undefined) {
throw new Error(`This agent only compatible with function calling models.`)
}
const modelWithTool = llm.bindTools([tool])
const outputParser = new ToolCallingAgentOutputParser()
supervisor = prompt
.pipe(modelWithTool)
.pipe(outputParser)
.pipe((x) => {
if (Array.isArray(x) && x.length) {
const toolAgentAction = x[0] as any
return {
next: toolAgentAction.toolInput.next,
instructions: toolAgentAction.toolInput.instructions,
team_members: members.join(', ')
}
} else if (typeof x === 'object' && 'returnValues' in x) {
return {
next: 'FINISH',
instructions: x.returnValues?.output,
team_members: members.join(', ')
}
} else {
return {
next: 'FINISH',
instructions: 'Conversation finished',
team_members: members.join(', ')
}
}
})
}
return supervisor
}
async function createTeamSupervisorWithSummarize(llm: BaseChatModel, systemPrompt: string, members: string[]): Promise<Runnable> {
const memberOptions = ['FINISH', ...members]
systemPrompt = systemPrompt.replaceAll('{team_members}', members.join(', '))
let userPrompt = `Given the conversation above, who should act next? Or should we FINISH? Select one of: ${memberOptions.join(
', '
)}
@ -247,7 +507,8 @@ class Supervisor_MultiAgents implements INode {
['human', userPrompt]
])
const messages = await processImageMessage(1, llm as any, prompt, nodeData, options)
// @ts-ignore
const messages = await processImageMessage(1, llm, prompt, nodeData, options)
prompt = messages.prompt
multiModalMessageContent = messages.multiModalMessageContent
@ -389,7 +650,9 @@ class Supervisor_MultiAgents implements INode {
return supervisor
}
const supervisorAgent = await createTeamSupervisor(llm, supervisorPrompt ? supervisorPrompt : sysPrompt, workersNodeNames)
const supervisorAgent = summarization
? await createTeamSupervisorWithSummarize(llm, supervisorPrompt ? supervisorPrompt : sysPrompt, workersNodeNames)
: await createTeamSupervisor(llm, supervisorPrompt ? supervisorPrompt : sysPrompt, workersNodeNames)
const supervisorNode = async (state: ITeamState, config: RunnableConfig) =>
await agentNode(
@ -433,7 +696,7 @@ async function agentNode(
throw new Error('Aborted!')
}
const result = await agent.invoke({ ...state, signal: abortControllerSignal.signal }, config)
const additional_kwargs: ICommonObject = { nodeId }
const additional_kwargs: ICommonObject = { nodeId, type: 'supervisor' }
result.additional_kwargs = { ...result.additional_kwargs, ...additional_kwargs }
return result
} catch (error) {

View File

@ -283,7 +283,7 @@ async function agentNode(
}
const result = await agent.invoke({ ...state, signal: abortControllerSignal.signal }, config)
const additional_kwargs: ICommonObject = { nodeId }
const additional_kwargs: ICommonObject = { nodeId, type: 'worker' }
if (result.usedTools) {
additional_kwargs.usedTools = result.usedTools
}

View File

@ -206,7 +206,7 @@ export interface ITeamState {
team_members: string[]
next: string
instructions: string
summarization: string
summarization?: string
}
export interface ISeqAgentsState {

View File

@ -147,6 +147,7 @@ export const buildAgentGraph = async (
let streamResults
let finalResult = ''
let finalSummarization = ''
let lastWorkerResult = ''
let agentReasoning: IAgentReasoning[] = []
let isSequential = false
let lastMessageRaw = {} as AIMessageChunk
@ -182,7 +183,8 @@ export const buildAgentGraph = async (
incomingInput.question,
chatHistory,
incomingInput?.overrideConfig,
sessionId || chatId
sessionId || chatId,
seqAgentNodes.some((node) => node.data.inputs?.summarization)
)
} else {
isSequential = true
@ -277,6 +279,12 @@ export const buildAgentGraph = async (
finalSummarization = output[agentName]?.summarization ?? ''
lastWorkerResult =
output[agentName]?.messages?.length &&
output[agentName].messages[output[agentName].messages.length - 1]?.additional_kwargs?.type === 'worker'
? output[agentName].messages[output[agentName].messages.length - 1].content
: lastWorkerResult
if (socketIO && incomingInput.socketIOClientId) {
if (!isStreamingStarted) {
isStreamingStarted = true
@ -305,10 +313,13 @@ export const buildAgentGraph = async (
/*
* For multi agents mode, sometimes finalResult is empty
* Provide summary as final result
* 1.) Provide lastWorkerResult as final result if available
* 2.) Provide summary as final result if available
*/
if (!isSequential && !finalResult && finalSummarization) {
finalResult = finalSummarization
if (!isSequential && !finalResult) {
if (lastWorkerResult) finalResult = lastWorkerResult
else if (finalSummarization) finalResult = finalSummarization
if (socketIO && incomingInput.socketIOClientId) {
socketIO.to(incomingInput.socketIOClientId).emit('token', finalResult)
}
@ -425,6 +436,7 @@ export const buildAgentGraph = async (
* @param {string} question
* @param {ICommonObject} overrideConfig
* @param {string} threadId
* @param {boolean} summarization
*/
const compileMultiAgentsGraph = async (
chatflow: IChatFlow,
@ -437,7 +449,8 @@ const compileMultiAgentsGraph = async (
question: string,
chatHistory: IMessage[] = [],
overrideConfig?: ICommonObject,
threadId?: string
threadId?: string,
summarization?: boolean
) => {
const appServer = getRunningExpressApp()
const channels: ITeamState = {
@ -447,10 +460,11 @@ const compileMultiAgentsGraph = async (
},
next: 'initialState',
instructions: "Solve the user's request.",
team_members: [],
summarization: 'summarize'
team_members: []
}
if (summarization) channels.summarization = 'summarize'
const workflowGraph = new StateGraph<ITeamState>({
//@ts-ignore
channels