Feature/Mistral FunctionAgent (#1912)

* add mistral ai agent, add used tools streaming

* fix AWS Bedrock imports

* update pnpm lock
This commit is contained in:
Henry Heng 2024-03-18 13:17:00 +08:00 committed by GitHub
parent 58122e985c
commit cd4c659009
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 30546 additions and 29817 deletions

View File

@ -9,7 +9,7 @@ import { RunnableSequence } from '@langchain/core/runnables'
import { ChatConversationalAgent } from 'langchain/agents' import { ChatConversationalAgent } from 'langchain/agents'
import { getBaseClasses } from '../../../src/utils' import { getBaseClasses } from '../../../src/utils'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { IVisionChatModal, FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface' import { IVisionChatModal, FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface'
import { AgentExecutor } from '../../../src/agents' import { AgentExecutor } from '../../../src/agents'
import { addImagesToMessages, llmSupportsVision } from '../../../src/multiModalUtils' import { addImagesToMessages, llmSupportsVision } from '../../../src/multiModalUtils'
import { checkInputs, Moderation } from '../../moderation/Moderation' import { checkInputs, Moderation } from '../../moderation/Moderation'
@ -120,12 +120,28 @@ class ConversationalAgent_Agents implements INode {
const callbacks = await additionalCallbacks(nodeData, options) const callbacks = await additionalCallbacks(nodeData, options)
let res: ChainValues = {} let res: ChainValues = {}
let sourceDocuments: ICommonObject[] = []
let usedTools: IUsedTool[] = []
if (options.socketIO && options.socketIOClientId) { if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] }) res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] })
if (res.sourceDocuments) {
options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
sourceDocuments = res.sourceDocuments
}
if (res.usedTools) {
options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
usedTools = res.usedTools
}
} else { } else {
res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] }) res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
if (res.sourceDocuments) {
sourceDocuments = res.sourceDocuments
}
if (res.usedTools) {
usedTools = res.usedTools
}
} }
await memory.addChatMessages( await memory.addChatMessages(
@ -142,7 +158,20 @@ class ConversationalAgent_Agents implements INode {
this.sessionId this.sessionId
) )
return res?.output let finalRes = res?.output
if (sourceDocuments.length || usedTools.length) {
finalRes = { text: res?.output }
if (sourceDocuments.length) {
finalRes.sourceDocuments = flatten(sourceDocuments)
}
if (usedTools.length) {
finalRes.usedTools = usedTools
}
return finalRes
}
return finalRes
} }
} }

View File

@ -25,6 +25,7 @@ class ConversationalRetrievalAgent_Agents implements INode {
category: string category: string
baseClasses: string[] baseClasses: string[]
inputs: INodeParams[] inputs: INodeParams[]
badge?: string
sessionId?: string sessionId?: string
constructor(fields?: { sessionId?: string }) { constructor(fields?: { sessionId?: string }) {
@ -33,6 +34,7 @@ class ConversationalRetrievalAgent_Agents implements INode {
this.version = 4.0 this.version = 4.0
this.type = 'AgentExecutor' this.type = 'AgentExecutor'
this.category = 'Agents' this.category = 'Agents'
this.badge = 'DEPRECATING'
this.icon = 'agent.svg' this.icon = 'agent.svg'
this.description = `An agent optimized for retrieval during conversation, answering questions based on past dialogue, all using OpenAI's Function Calling` this.description = `An agent optimized for retrieval during conversation, answering questions based on past dialogue, all using OpenAI's Function Calling`
this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)] this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)]

View File

@ -0,0 +1 @@
<svg width="32" height="32" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M5 6H4v19.5h1m8-7.5v3h1m7-11.5V6h1m-5 7.5V10h1" stroke="#000" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/><mask id="MistralAI__a" style="mask-type:alpha" maskUnits="userSpaceOnUse" x="5" y="6" width="22" height="20"><path d="M5 6v19.5h5v-8h4V21h4v-3.5h4V25h5V6h-4.5v4H18v3.5h-4v-4h-4V6H5Z" fill="#FD7000"/></mask><g mask="url(#MistralAI__a)"><path fill="#FFCD00" d="M4 6h25v4H4z"/></g><mask id="MistralAI__b" style="mask-type:alpha" maskUnits="userSpaceOnUse" x="5" y="6" width="22" height="20"><path d="M5 6v19.5h5v-8h4V21h4v-3.5h4V25h5V6h-4.5v4H18v3.5h-4v-4h-4V6H5Z" fill="#FD7000"/></mask><g mask="url(#MistralAI__b)"><path fill="#FFA200" d="M4 10h25v4H4z"/></g><mask id="MistralAI__c" style="mask-type:alpha" maskUnits="userSpaceOnUse" x="5" y="6" width="22" height="20"><path d="M5 6v19.5h5v-8h4V21h4v-3.5h4V25h5V6h-4.5v4H18v3.5h-4v-4h-4V6H5Z" fill="#FD7000"/></mask><g mask="url(#MistralAI__c)"><path fill="#FF6E00" d="M4 14h25v4H4z"/></g><mask id="MistralAI__d" style="mask-type:alpha" maskUnits="userSpaceOnUse" x="5" y="6" width="22" height="20"><path d="M5 6v19.5h5v-8h4V21h4v-3.5h4V25h5V6h-4.5v4H18v3.5h-4v-4h-4V6H5Z" fill="#FD7000"/></mask><g mask="url(#MistralAI__d)"><path fill="#FF4A09" d="M4 18h25v4H4z"/></g><mask id="MistralAI__e" style="mask-type:alpha" maskUnits="userSpaceOnUse" x="5" y="6" width="22" height="20"><path d="M5 6v19.5h5v-8h4V21h4v-3.5h4V25h5V6h-4.5v4H18v3.5h-4v-4h-4V6H5Z" fill="#FD7000"/></mask><g mask="url(#MistralAI__e)"><path fill="#FE060F" d="M4 22h25v4H4z"/></g><path d="M21 18v7h1" stroke="#000" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/><path d="M5 6v19.5h5v-8h4V21h4v-3.5h4V25h5V6h-4.5v4H18v3.5h-4v-4h-4V6H5Z" stroke="#000" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/></svg>

After

Width:  |  Height:  |  Size: 1.8 KiB

View File

@ -0,0 +1,207 @@
import { flatten } from 'lodash'
import { BaseMessage } from '@langchain/core/messages'
import { ChainValues } from '@langchain/core/utils/types'
import { AgentStep } from '@langchain/core/agents'
import { RunnableSequence } from '@langchain/core/runnables'
import { ChatOpenAI } from '@langchain/openai'
import { convertToOpenAITool } from '@langchain/core/utils/function_calling'
import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
import { OpenAIToolsAgentOutputParser } from 'langchain/agents/openai/output_parser'
import { getBaseClasses } from '../../../src/utils'
import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { AgentExecutor, formatAgentSteps } from '../../../src/agents'
import { Moderation, checkInputs, streamResponse } from '../../moderation/Moderation'
import { formatResponse } from '../../outputparsers/OutputParserHelpers'
class MistralAIFunctionAgent_Agents implements INode {
label: string
name: string
version: number
description: string
type: string
icon: string
category: string
baseClasses: string[]
inputs: INodeParams[]
sessionId?: string
badge?: string
constructor(fields?: { sessionId?: string }) {
this.label = 'MistralAI Function Agent'
this.name = 'mistralAIFunctionAgent'
this.version = 1.0
this.type = 'AgentExecutor'
this.category = 'Agents'
this.icon = 'MistralAI.svg'
this.badge = 'NEW'
this.description = `An agent that uses MistralAI Function Calling to pick the tool and args to call`
this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)]
this.inputs = [
{
label: 'Tools',
name: 'tools',
type: 'Tool',
list: true
},
{
label: 'Memory',
name: 'memory',
type: 'BaseChatMemory'
},
{
label: 'MistralAI Chat Model',
name: 'model',
type: 'BaseChatModel'
},
{
label: 'System Message',
name: 'systemMessage',
type: 'string',
rows: 4,
optional: true,
additionalParams: true
},
{
label: 'Input Moderation',
description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
name: 'inputModeration',
type: 'Moderation',
optional: true,
list: true
}
]
this.sessionId = fields?.sessionId
}
async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)
}
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
const memory = nodeData.inputs?.memory as FlowiseMemory
const moderations = nodeData.inputs?.inputModeration as Moderation[]
if (moderations && moderations.length > 0) {
try {
// Use the output of the moderation chain as input for the OpenAI Function Agent
input = await checkInputs(moderations, input)
} catch (e) {
await new Promise((resolve) => setTimeout(resolve, 500))
streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId)
return formatResponse(e.message)
}
}
const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)
const loggerHandler = new ConsoleCallbackHandler(options.logger)
const callbacks = await additionalCallbacks(nodeData, options)
let res: ChainValues = {}
let sourceDocuments: ICommonObject[] = []
let usedTools: IUsedTool[] = []
if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] })
if (res.sourceDocuments) {
options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
sourceDocuments = res.sourceDocuments
}
if (res.usedTools) {
options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
usedTools = res.usedTools
}
} else {
res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
if (res.sourceDocuments) {
sourceDocuments = res.sourceDocuments
}
if (res.usedTools) {
usedTools = res.usedTools
}
}
await memory.addChatMessages(
[
{
text: input,
type: 'userMessage'
},
{
text: res?.output,
type: 'apiMessage'
}
],
this.sessionId
)
let finalRes = res?.output
if (sourceDocuments.length || usedTools.length) {
finalRes = { text: res?.output }
if (sourceDocuments.length) {
finalRes.sourceDocuments = flatten(sourceDocuments)
}
if (usedTools.length) {
finalRes.usedTools = usedTools
}
return finalRes
}
return finalRes
}
}
const prepareAgent = (
nodeData: INodeData,
flowObj: { sessionId?: string; chatId?: string; input?: string },
chatHistory: IMessage[] = []
) => {
const model = nodeData.inputs?.model as ChatOpenAI
const memory = nodeData.inputs?.memory as FlowiseMemory
const systemMessage = nodeData.inputs?.systemMessage as string
let tools = nodeData.inputs?.tools
tools = flatten(tools)
const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history'
const inputKey = memory.inputKey ? memory.inputKey : 'input'
const prompt = ChatPromptTemplate.fromMessages([
['system', systemMessage ? systemMessage : `You are a helpful AI assistant.`],
new MessagesPlaceholder(memoryKey),
['human', `{${inputKey}}`],
new MessagesPlaceholder('agent_scratchpad')
])
const llmWithTools = model.bind({
tools: tools.map(convertToOpenAITool)
})
const runnableAgent = RunnableSequence.from([
{
[inputKey]: (i: { input: string; steps: AgentStep[] }) => i.input,
agent_scratchpad: (i: { input: string; steps: AgentStep[] }) => formatAgentSteps(i.steps),
[memoryKey]: async (_: { input: string; steps: AgentStep[] }) => {
const messages = (await memory.getChatMessages(flowObj?.sessionId, true, chatHistory)) as BaseMessage[]
return messages ?? []
}
},
prompt,
llmWithTools,
new OpenAIToolsAgentOutputParser()
])
const executor = AgentExecutor.fromAgentAndTools({
agent: runnableAgent,
tools,
sessionId: flowObj?.sessionId,
chatId: flowObj?.chatId,
input: flowObj?.input,
verbose: process.env.DEBUG === 'true' ? true : false
})
return executor
}
module.exports = { nodeClass: MistralAIFunctionAgent_Agents }

View File

@ -7,7 +7,7 @@ import { ChatOpenAI, formatToOpenAIFunction } from '@langchain/openai'
import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts' import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
import { OpenAIFunctionsAgentOutputParser } from 'langchain/agents/openai/output_parser' import { OpenAIFunctionsAgentOutputParser } from 'langchain/agents/openai/output_parser'
import { getBaseClasses } from '../../../src/utils' import { getBaseClasses } from '../../../src/utils'
import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface' import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { AgentExecutor, formatAgentSteps } from '../../../src/agents' import { AgentExecutor, formatAgentSteps } from '../../../src/agents'
import { Moderation, checkInputs } from '../../moderation/Moderation' import { Moderation, checkInputs } from '../../moderation/Moderation'
@ -97,6 +97,7 @@ class OpenAIFunctionAgent_Agents implements INode {
let res: ChainValues = {} let res: ChainValues = {}
let sourceDocuments: ICommonObject[] = [] let sourceDocuments: ICommonObject[] = []
let usedTools: IUsedTool[] = []
if (options.socketIO && options.socketIOClientId) { if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
@ -105,11 +106,18 @@ class OpenAIFunctionAgent_Agents implements INode {
options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments)) options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
sourceDocuments = res.sourceDocuments sourceDocuments = res.sourceDocuments
} }
if (res.usedTools) {
options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
usedTools = res.usedTools
}
} else { } else {
res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] }) res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
if (res.sourceDocuments) { if (res.sourceDocuments) {
sourceDocuments = res.sourceDocuments sourceDocuments = res.sourceDocuments
} }
if (res.usedTools) {
usedTools = res.usedTools
}
} }
await memory.addChatMessages( await memory.addChatMessages(
@ -126,7 +134,20 @@ class OpenAIFunctionAgent_Agents implements INode {
this.sessionId this.sessionId
) )
return sourceDocuments.length ? { text: res?.output, sourceDocuments: flatten(sourceDocuments) } : res?.output let finalRes = res?.output
if (sourceDocuments.length || usedTools.length) {
finalRes = { text: res?.output }
if (sourceDocuments.length) {
finalRes.sourceDocuments = flatten(sourceDocuments)
}
if (usedTools.length) {
finalRes.usedTools = usedTools
}
return finalRes
}
return finalRes
} }
} }

View File

@ -7,7 +7,7 @@ import { Tool } from '@langchain/core/tools'
import { ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts' import { ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
import { formatLogToMessage } from 'langchain/agents/format_scratchpad/log_to_message' import { formatLogToMessage } from 'langchain/agents/format_scratchpad/log_to_message'
import { getBaseClasses } from '../../../src/utils' import { getBaseClasses } from '../../../src/utils'
import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface' import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { AgentExecutor, XMLAgentOutputParser } from '../../../src/agents' import { AgentExecutor, XMLAgentOutputParser } from '../../../src/agents'
import { Moderation, checkInputs } from '../../moderation/Moderation' import { Moderation, checkInputs } from '../../moderation/Moderation'
@ -48,6 +48,7 @@ class XMLAgent_Agents implements INode {
baseClasses: string[] baseClasses: string[]
inputs: INodeParams[] inputs: INodeParams[]
sessionId?: string sessionId?: string
badge?: string
constructor(fields?: { sessionId?: string }) { constructor(fields?: { sessionId?: string }) {
this.label = 'XML Agent' this.label = 'XML Agent'
@ -56,6 +57,7 @@ class XMLAgent_Agents implements INode {
this.type = 'XMLAgent' this.type = 'XMLAgent'
this.category = 'Agents' this.category = 'Agents'
this.icon = 'xmlagent.svg' this.icon = 'xmlagent.svg'
this.badge = 'NEW'
this.description = `Agent that is designed for LLMs that are good for reasoning/writing XML (e.g: Anthropic Claude)` this.description = `Agent that is designed for LLMs that are good for reasoning/writing XML (e.g: Anthropic Claude)`
this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)] this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)]
this.inputs = [ this.inputs = [
@ -121,6 +123,7 @@ class XMLAgent_Agents implements INode {
let res: ChainValues = {} let res: ChainValues = {}
let sourceDocuments: ICommonObject[] = [] let sourceDocuments: ICommonObject[] = []
let usedTools: IUsedTool[] = []
if (options.socketIO && options.socketIOClientId) { if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
@ -129,11 +132,18 @@ class XMLAgent_Agents implements INode {
options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments)) options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
sourceDocuments = res.sourceDocuments sourceDocuments = res.sourceDocuments
} }
if (res.usedTools) {
options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
usedTools = res.usedTools
}
} else { } else {
res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] }) res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
if (res.sourceDocuments) { if (res.sourceDocuments) {
sourceDocuments = res.sourceDocuments sourceDocuments = res.sourceDocuments
} }
if (res.usedTools) {
usedTools = res.usedTools
}
} }
await memory.addChatMessages( await memory.addChatMessages(
@ -150,7 +160,20 @@ class XMLAgent_Agents implements INode {
this.sessionId this.sessionId
) )
return sourceDocuments.length ? { text: res?.output, sourceDocuments: flatten(sourceDocuments) } : res?.output let finalRes = res?.output
if (sourceDocuments.length || usedTools.length) {
finalRes = { text: res?.output }
if (sourceDocuments.length) {
finalRes.sourceDocuments = flatten(sourceDocuments)
}
if (usedTools.length) {
finalRes.usedTools = usedTools
}
return finalRes
}
return finalRes
} }
} }

View File

@ -1,8 +1,36 @@
import { ChatCompletionResponse, ToolCalls as MistralAIToolCalls } from '@mistralai/mistralai'
import { BaseCache } from '@langchain/core/caches' import { BaseCache } from '@langchain/core/caches'
import { ChatMistralAI, ChatMistralAIInput } from '@langchain/mistralai' import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
import { NewTokenIndices } from '@langchain/core/callbacks/base'
import { ChatGeneration, ChatGenerationChunk, ChatResult } from '@langchain/core/outputs'
import {
MessageType,
type BaseMessage,
MessageContent,
AIMessage,
HumanMessage,
HumanMessageChunk,
AIMessageChunk,
ToolMessageChunk,
ChatMessageChunk
} from '@langchain/core/messages'
import { ChatMistralAI as LangchainChatMistralAI, ChatMistralAIInput } from '@langchain/mistralai'
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
interface TokenUsage {
completionTokens?: number
promptTokens?: number
totalTokens?: number
}
type MistralAIInputMessage = {
role: string
name?: string
content: string | string[]
tool_calls?: MistralAIToolCalls[] | any[]
}
class ChatMistral_ChatModels implements INode { class ChatMistral_ChatModels implements INode {
label: string label: string
name: string name: string
@ -135,4 +163,243 @@ class ChatMistral_ChatModels implements INode {
} }
} }
class ChatMistralAI extends LangchainChatMistralAI {
async _generate(
messages: BaseMessage[],
options?: this['ParsedCallOptions'],
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
const tokenUsage: TokenUsage = {}
const params = this.invocationParams(options)
const mistralMessages = this.convertMessagesToMistralMessages(messages)
const input = {
...params,
messages: mistralMessages
}
// Handle streaming
if (this.streaming) {
const stream = this._streamResponseChunks(messages, options, runManager)
const finalChunks: Record<number, ChatGenerationChunk> = {}
for await (const chunk of stream) {
const index = (chunk.generationInfo as NewTokenIndices)?.completion ?? 0
if (finalChunks[index] === undefined) {
finalChunks[index] = chunk
} else {
finalChunks[index] = finalChunks[index].concat(chunk)
}
}
const generations = Object.entries(finalChunks)
.sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10))
.map(([_, value]) => value)
return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } }
}
// Not streaming, so we can just call the API once.
const response = await this.completionWithRetry(input, false)
const { completion_tokens: completionTokens, prompt_tokens: promptTokens, total_tokens: totalTokens } = response?.usage ?? {}
if (completionTokens) {
tokenUsage.completionTokens = (tokenUsage.completionTokens ?? 0) + completionTokens
}
if (promptTokens) {
tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens
}
if (totalTokens) {
tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens
}
const generations: ChatGeneration[] = []
for (const part of response?.choices ?? []) {
if ('delta' in part) {
throw new Error('Delta not supported in non-streaming mode.')
}
if (!('message' in part)) {
throw new Error('No message found in the choice.')
}
const text = part.message?.content ?? ''
const generation: ChatGeneration = {
text,
message: this.mistralAIResponseToChatMessage(part)
}
if (part.finish_reason) {
generation.generationInfo = { finish_reason: part.finish_reason }
}
generations.push(generation)
}
return {
generations,
llmOutput: { tokenUsage }
}
}
async *_streamResponseChunks(
messages: BaseMessage[],
options?: this['ParsedCallOptions'],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const mistralMessages = this.convertMessagesToMistralMessages(messages)
const params = this.invocationParams(options)
const input = {
...params,
messages: mistralMessages
}
const streamIterable = await this.completionWithRetry(input, true)
for await (const data of streamIterable) {
const choice = data?.choices[0]
if (!choice || !('delta' in choice)) {
continue
}
const { delta } = choice
if (!delta) {
continue
}
const newTokenIndices = {
prompt: 0,
completion: choice.index ?? 0
}
const message = this._convertDeltaToMessageChunk(delta)
if (message === null) {
// Do not yield a chunk if the message is empty
continue
}
const generationChunk = new ChatGenerationChunk({
message,
text: delta.content ?? '',
generationInfo: newTokenIndices
})
yield generationChunk
void runManager?.handleLLMNewToken(generationChunk.text ?? '', newTokenIndices, undefined, undefined, undefined, {
chunk: generationChunk
})
}
if (options?.signal?.aborted) {
throw new Error('AbortError')
}
}
_convertDeltaToMessageChunk(delta: {
role?: string | undefined
content?: string | undefined
tool_calls?: MistralAIToolCalls[] | undefined
}) {
if (!delta.content && !delta.tool_calls) {
return null
}
// Our merge additional kwargs util function will throw unless there
// is an index key in each tool object (as seen in OpenAI's) so we
// need to insert it here.
const toolCallsWithIndex = delta.tool_calls?.length
? delta.tool_calls?.map((toolCall, index) => ({
...toolCall,
index
}))
: undefined
let role = 'assistant'
if (delta.role) {
role = delta.role
} else if (toolCallsWithIndex) {
role = 'tool'
}
const content = delta.content ?? ''
let additional_kwargs
if (toolCallsWithIndex) {
additional_kwargs = {
tool_calls: toolCallsWithIndex
}
} else {
additional_kwargs = {}
}
if (role === 'user') {
return new HumanMessageChunk({ content })
} else if (role === 'assistant') {
return new AIMessageChunk({ content, additional_kwargs })
} else if (role === 'tool') {
return new ToolMessageChunk({
content,
additional_kwargs,
tool_call_id: toolCallsWithIndex?.[0].id ?? ''
})
} else {
return new ChatMessageChunk({ content, role })
}
}
convertMessagesToMistralMessages(messages: Array<BaseMessage>): Array<MistralAIInputMessage> {
const getRole = (role: MessageType) => {
switch (role) {
case 'human':
return 'user'
case 'ai':
return 'assistant'
case 'tool':
return 'tool'
case 'function':
return 'function'
case 'system':
return 'system'
default:
throw new Error(`Unknown message type: ${role}`)
}
}
const getContent = (content: MessageContent): string => {
if (typeof content === 'string') {
return content
}
throw new Error(`ChatMistralAI does not support non text message content. Received: ${JSON.stringify(content, null, 2)}`)
}
const mistralMessages = []
for (const msg of messages) {
const msgObj: MistralAIInputMessage = {
role: getRole(msg._getType()),
content: getContent(msg.content)
}
if (getRole(msg._getType()) === 'tool') {
msgObj.role = 'assistant'
msgObj.tool_calls = msg.additional_kwargs?.tool_calls ?? []
} else if (getRole(msg._getType()) === 'function') {
msgObj.role = 'tool'
msgObj.name = msg.name
}
mistralMessages.push(msgObj)
}
return mistralMessages
}
mistralAIResponseToChatMessage(choice: ChatCompletionResponse['choices'][0]): BaseMessage {
const { message } = choice
// MistralAI SDK does not include tool_calls in the non
// streaming return type, so we need to extract it like this
// to satisfy typescript.
let toolCalls: MistralAIToolCalls[] = []
if ('tool_calls' in message) {
toolCalls = message.tool_calls as MistralAIToolCalls[]
}
switch (message.role) {
case 'assistant':
return new AIMessage({
content: message.content ?? '',
additional_kwargs: {
tool_calls: toolCalls
}
})
default:
return new HumanMessage(message.content ?? '')
}
}
}
module.exports = { nodeClass: ChatMistral_ChatModels } module.exports = { nodeClass: ChatMistral_ChatModels }

View File

@ -38,9 +38,10 @@
"@langchain/community": "^0.0.39", "@langchain/community": "^0.0.39",
"@langchain/google-genai": "^0.0.10", "@langchain/google-genai": "^0.0.10",
"@langchain/groq": "^0.0.2", "@langchain/groq": "^0.0.2",
"@langchain/mistralai": "^0.0.7", "@langchain/mistralai": "^0.0.11",
"@langchain/openai": "^0.0.14", "@langchain/openai": "^0.0.14",
"@langchain/pinecone": "^0.0.3", "@langchain/pinecone": "^0.0.3",
"@mistralai/mistralai": "0.1.3",
"@notionhq/client": "^2.2.8", "@notionhq/client": "^2.2.8",
"@opensearch-project/opensearch": "^1.2.0", "@opensearch-project/opensearch": "^1.2.0",
"@pinecone-database/pinecone": "^2.0.1", "@pinecone-database/pinecone": "^2.0.1",
@ -70,7 +71,7 @@
"ioredis": "^5.3.2", "ioredis": "^5.3.2",
"jsdom": "^22.1.0", "jsdom": "^22.1.0",
"jsonpointer": "^5.0.1", "jsonpointer": "^5.0.1",
"langchain": "^0.1.20", "langchain": "^0.1.26",
"langfuse": "3.3.1", "langfuse": "3.3.1",
"langfuse-langchain": "^3.3.1", "langfuse-langchain": "^3.3.1",
"langsmith": "0.1.6", "langsmith": "0.1.6",

View File

@ -20,6 +20,7 @@ import {
StoppingMethod StoppingMethod
} from 'langchain/agents' } from 'langchain/agents'
import { formatLogToString } from 'langchain/agents/format_scratchpad/log' import { formatLogToString } from 'langchain/agents/format_scratchpad/log'
import { IUsedTool } from './Interface'
export const SOURCE_DOCUMENTS_PREFIX = '\n\n----FLOWISE_SOURCE_DOCUMENTS----\n\n' export const SOURCE_DOCUMENTS_PREFIX = '\n\n----FLOWISE_SOURCE_DOCUMENTS----\n\n'
type AgentFinish = { type AgentFinish = {
@ -341,11 +342,13 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
const steps: AgentStep[] = [] const steps: AgentStep[] = []
let iterations = 0 let iterations = 0
let sourceDocuments: Array<Document> = [] let sourceDocuments: Array<Document> = []
const usedTools: IUsedTool[] = []
const getOutput = async (finishStep: AgentFinish): Promise<AgentExecutorOutput> => { const getOutput = async (finishStep: AgentFinish): Promise<AgentExecutorOutput> => {
const { returnValues } = finishStep const { returnValues } = finishStep
const additional = await this.agent.prepareForOutput(returnValues, steps) const additional = await this.agent.prepareForOutput(returnValues, steps)
if (sourceDocuments.length) additional.sourceDocuments = flatten(sourceDocuments) if (sourceDocuments.length) additional.sourceDocuments = flatten(sourceDocuments)
if (usedTools.length) additional.usedTools = usedTools
if (this.returnIntermediateSteps) { if (this.returnIntermediateSteps) {
return { ...returnValues, intermediateSteps: steps, ...additional } return { ...returnValues, intermediateSteps: steps, ...additional }
@ -410,18 +413,27 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
* - tags?: string[] * - tags?: string[]
* - flowConfig?: { sessionId?: string, chatId?: string, input?: string } * - flowConfig?: { sessionId?: string, chatId?: string, input?: string }
*/ */
observation = tool if (tool) {
? await (tool as any).call( observation = await (tool as any).call(
this.isXML && typeof action.toolInput === 'string' ? { input: action.toolInput } : action.toolInput, this.isXML && typeof action.toolInput === 'string' ? { input: action.toolInput } : action.toolInput,
runManager?.getChild(), runManager?.getChild(),
undefined, undefined,
{ {
sessionId: this.sessionId, sessionId: this.sessionId,
chatId: this.chatId, chatId: this.chatId,
input: this.input input: this.input
} }
) )
: `${action.tool} is not a valid tool, try another one.` usedTools.push({
tool: tool.name,
toolInput: action.toolInput as any,
toolOutput: observation.includes(SOURCE_DOCUMENTS_PREFIX)
? observation.split(SOURCE_DOCUMENTS_PREFIX)[0]
: observation
})
} else {
observation = `${action.tool} is not a valid tool, try another one.`
}
} catch (e) { } catch (e) {
if (e instanceof ToolInputParsingException) { if (e instanceof ToolInputParsingException) {
if (this.handleParsingErrors === true) { if (this.handleParsingErrors === true) {

View File

@ -905,7 +905,13 @@ export const isFlowValidForStream = (reactFlowNodes: IReactFlowNode[], endingNod
isValidChainOrAgent = !blacklistChains.includes(endingNodeData.name) isValidChainOrAgent = !blacklistChains.includes(endingNodeData.name)
} else if (endingNodeData.category === 'Agents') { } else if (endingNodeData.category === 'Agents') {
// Agent that are available to stream // Agent that are available to stream
const whitelistAgents = ['openAIFunctionAgent', 'csvAgent', 'airtableAgent', 'conversationalRetrievalAgent'] const whitelistAgents = [
'openAIFunctionAgent',
'mistralAIFunctionAgent',
'csvAgent',
'airtableAgent',
'conversationalRetrievalAgent'
]
isValidChainOrAgent = whitelistAgents.includes(endingNodeData.name) isValidChainOrAgent = whitelistAgents.includes(endingNodeData.name)
} else if (endingNodeData.category === 'Engine') { } else if (endingNodeData.category === 'Engine') {
const whitelistEngine = ['contextChatEngine', 'simpleChatEngine', 'queryEngine', 'subQuestionQueryEngine'] const whitelistEngine = ['contextChatEngine', 'simpleChatEngine', 'queryEngine', 'subQuestionQueryEngine']

View File

@ -23,7 +23,7 @@ import {
Typography Typography
} from '@mui/material' } from '@mui/material'
import { useTheme } from '@mui/material/styles' import { useTheme } from '@mui/material/styles'
import { IconCircleDot, IconDownload, IconSend, IconMicrophone, IconPhotoPlus, IconTrash, IconX } from '@tabler/icons' import { IconCircleDot, IconDownload, IconSend, IconMicrophone, IconPhotoPlus, IconTrash, IconX, IconTool } from '@tabler/icons'
import robotPNG from '@/assets/images/robot.png' import robotPNG from '@/assets/images/robot.png'
import userPNG from '@/assets/images/account.png' import userPNG from '@/assets/images/account.png'
import audioUploadSVG from '@/assets/images/wave-sound.jpg' import audioUploadSVG from '@/assets/images/wave-sound.jpg'
@ -340,6 +340,15 @@ export const ChatMessage = ({ open, chatflowid, isDialog, previews, setPreviews
}) })
} }
const updateLastMessageUsedTools = (usedTools) => {
setMessages((prevMessages) => {
let allMessages = [...cloneDeep(prevMessages)]
if (allMessages[allMessages.length - 1].type === 'userMessage') return allMessages
allMessages[allMessages.length - 1].usedTools = usedTools
return allMessages
})
}
// Handle errors // Handle errors
const handleError = (message = 'Oops! There seems to be an error. Please try again.') => { const handleError = (message = 'Oops! There seems to be an error. Please try again.') => {
message = message.replace(`Unable to parse JSON response from chat agent.\n\n`, '') message = message.replace(`Unable to parse JSON response from chat agent.\n\n`, '')
@ -596,6 +605,8 @@ export const ChatMessage = ({ open, chatflowid, isDialog, previews, setPreviews
socket.on('sourceDocuments', updateLastMessageSourceDocuments) socket.on('sourceDocuments', updateLastMessageSourceDocuments)
socket.on('usedTools', updateLastMessageUsedTools)
socket.on('token', updateLastMessage) socket.on('token', updateLastMessage)
} }
@ -770,6 +781,7 @@ export const ChatMessage = ({ open, chatflowid, isDialog, previews, setPreviews
sx={{ mr: 1, mt: 1 }} sx={{ mr: 1, mt: 1 }}
variant='outlined' variant='outlined'
clickable clickable
icon={<IconTool size={15} />}
onClick={() => onSourceDialogClick(tool, 'Used Tools')} onClick={() => onSourceDialogClick(tool, 'Used Tools')}
/> />
) )

File diff suppressed because it is too large Load Diff

View File

@ -1,2 +1,2 @@
packages: packages:
- 'packages/*' - 'packages/*'