Feature/Mistral FunctionAgent (#1912)
* Add MistralAI Function Agent; stream the tools used by agents
* Fix AWS Bedrock imports
* Update pnpm lock
This commit is contained in:
parent 58122e985c
commit cd4c659009
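
Note: the headline change below is that agent nodes now surface which tools they invoked. The executor collects them, the agent nodes return them as `usedTools`, and the server emits them to the UI over socket.io as a new 'usedTools' event next to the existing 'sourceDocuments' and 'token' events. A minimal sketch of how a socket.io client could consume the new event follows; the event names come from this commit, while the server URL and the payload handling are illustrative assumptions.

import { io } from 'socket.io-client'

// Assumed local Flowise server; replace with your own URL.
const socket = io('http://localhost:3000')

// Streamed answer tokens (existing behaviour).
socket.on('token', (token: string) => process.stdout.write(token))

// New in this commit: the list of tools the agent called, with their inputs and outputs.
socket.on('usedTools', (usedTools: Array<{ tool: string; toolInput: unknown; toolOutput: string }>) => {
    for (const t of usedTools) console.log(`\n[tool] ${t.tool} ->`, t.toolOutput)
})

// Existing event, still emitted for retrieval-backed answers.
socket.on('sourceDocuments', (docs: unknown[]) => console.log(`\n${docs.length} source document(s)`))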
@@ -9,7 +9,7 @@ import { RunnableSequence } from '@langchain/core/runnables'
 import { ChatConversationalAgent } from 'langchain/agents'
 import { getBaseClasses } from '../../../src/utils'
 import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
-import { IVisionChatModal, FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
+import { IVisionChatModal, FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface'
 import { AgentExecutor } from '../../../src/agents'
 import { addImagesToMessages, llmSupportsVision } from '../../../src/multiModalUtils'
 import { checkInputs, Moderation } from '../../moderation/Moderation'
@@ -120,12 +120,28 @@ class ConversationalAgent_Agents implements INode {
         const callbacks = await additionalCallbacks(nodeData, options)

         let res: ChainValues = {}
+        let sourceDocuments: ICommonObject[] = []
+        let usedTools: IUsedTool[] = []

         if (options.socketIO && options.socketIOClientId) {
             const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
             res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] })
+            if (res.sourceDocuments) {
+                options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
+                sourceDocuments = res.sourceDocuments
+            }
+            if (res.usedTools) {
+                options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
+                usedTools = res.usedTools
+            }
         } else {
             res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
+            if (res.sourceDocuments) {
+                sourceDocuments = res.sourceDocuments
+            }
+            if (res.usedTools) {
+                usedTools = res.usedTools
+            }
         }

         await memory.addChatMessages(
@@ -142,7 +158,20 @@ class ConversationalAgent_Agents implements INode {
             this.sessionId
         )

-        return res?.output
+        let finalRes = res?.output
+
+        if (sourceDocuments.length || usedTools.length) {
+            finalRes = { text: res?.output }
+            if (sourceDocuments.length) {
+                finalRes.sourceDocuments = flatten(sourceDocuments)
+            }
+            if (usedTools.length) {
+                finalRes.usedTools = usedTools
+            }
+            return finalRes
+        }
+
+        return finalRes
     }
 }

@@ -25,6 +25,7 @@ class ConversationalRetrievalAgent_Agents implements INode {
     category: string
     baseClasses: string[]
     inputs: INodeParams[]
+    badge?: string
     sessionId?: string

     constructor(fields?: { sessionId?: string }) {
@@ -33,6 +34,7 @@ class ConversationalRetrievalAgent_Agents implements INode {
         this.version = 4.0
         this.type = 'AgentExecutor'
         this.category = 'Agents'
+        this.badge = 'DEPRECATING'
         this.icon = 'agent.svg'
         this.description = `An agent optimized for retrieval during conversation, answering questions based on past dialogue, all using OpenAI's Function Calling`
         this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)]

New file: MistralAI SVG icon (1.8 KiB)
@@ -0,0 +1 @@
<svg width="32" height="32" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M5 6H4v19.5h1m8-7.5v3h1m7-11.5V6h1m-5 7.5V10h1" stroke="#000" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/><mask id="MistralAI__a" style="mask-type:alpha" maskUnits="userSpaceOnUse" x="5" y="6" width="22" height="20"><path d="M5 6v19.5h5v-8h4V21h4v-3.5h4V25h5V6h-4.5v4H18v3.5h-4v-4h-4V6H5Z" fill="#FD7000"/></mask><g mask="url(#MistralAI__a)"><path fill="#FFCD00" d="M4 6h25v4H4z"/></g><mask id="MistralAI__b" style="mask-type:alpha" maskUnits="userSpaceOnUse" x="5" y="6" width="22" height="20"><path d="M5 6v19.5h5v-8h4V21h4v-3.5h4V25h5V6h-4.5v4H18v3.5h-4v-4h-4V6H5Z" fill="#FD7000"/></mask><g mask="url(#MistralAI__b)"><path fill="#FFA200" d="M4 10h25v4H4z"/></g><mask id="MistralAI__c" style="mask-type:alpha" maskUnits="userSpaceOnUse" x="5" y="6" width="22" height="20"><path d="M5 6v19.5h5v-8h4V21h4v-3.5h4V25h5V6h-4.5v4H18v3.5h-4v-4h-4V6H5Z" fill="#FD7000"/></mask><g mask="url(#MistralAI__c)"><path fill="#FF6E00" d="M4 14h25v4H4z"/></g><mask id="MistralAI__d" style="mask-type:alpha" maskUnits="userSpaceOnUse" x="5" y="6" width="22" height="20"><path d="M5 6v19.5h5v-8h4V21h4v-3.5h4V25h5V6h-4.5v4H18v3.5h-4v-4h-4V6H5Z" fill="#FD7000"/></mask><g mask="url(#MistralAI__d)"><path fill="#FF4A09" d="M4 18h25v4H4z"/></g><mask id="MistralAI__e" style="mask-type:alpha" maskUnits="userSpaceOnUse" x="5" y="6" width="22" height="20"><path d="M5 6v19.5h5v-8h4V21h4v-3.5h4V25h5V6h-4.5v4H18v3.5h-4v-4h-4V6H5Z" fill="#FD7000"/></mask><g mask="url(#MistralAI__e)"><path fill="#FE060F" d="M4 22h25v4H4z"/></g><path d="M21 18v7h1" stroke="#000" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/><path d="M5 6v19.5h5v-8h4V21h4v-3.5h4V25h5V6h-4.5v4H18v3.5h-4v-4h-4V6H5Z" stroke="#000" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/></svg>

New file: MistralAI Function Agent node
@@ -0,0 +1,207 @@
import { flatten } from 'lodash'
import { BaseMessage } from '@langchain/core/messages'
import { ChainValues } from '@langchain/core/utils/types'
import { AgentStep } from '@langchain/core/agents'
import { RunnableSequence } from '@langchain/core/runnables'
import { ChatOpenAI } from '@langchain/openai'
import { convertToOpenAITool } from '@langchain/core/utils/function_calling'
import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
import { OpenAIToolsAgentOutputParser } from 'langchain/agents/openai/output_parser'
import { getBaseClasses } from '../../../src/utils'
import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { AgentExecutor, formatAgentSteps } from '../../../src/agents'
import { Moderation, checkInputs, streamResponse } from '../../moderation/Moderation'
import { formatResponse } from '../../outputparsers/OutputParserHelpers'

class MistralAIFunctionAgent_Agents implements INode {
    label: string
    name: string
    version: number
    description: string
    type: string
    icon: string
    category: string
    baseClasses: string[]
    inputs: INodeParams[]
    sessionId?: string
    badge?: string

    constructor(fields?: { sessionId?: string }) {
        this.label = 'MistralAI Function Agent'
        this.name = 'mistralAIFunctionAgent'
        this.version = 1.0
        this.type = 'AgentExecutor'
        this.category = 'Agents'
        this.icon = 'MistralAI.svg'
        this.badge = 'NEW'
        this.description = `An agent that uses MistralAI Function Calling to pick the tool and args to call`
        this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)]
        this.inputs = [
            {
                label: 'Tools',
                name: 'tools',
                type: 'Tool',
                list: true
            },
            {
                label: 'Memory',
                name: 'memory',
                type: 'BaseChatMemory'
            },
            {
                label: 'MistralAI Chat Model',
                name: 'model',
                type: 'BaseChatModel'
            },
            {
                label: 'System Message',
                name: 'systemMessage',
                type: 'string',
                rows: 4,
                optional: true,
                additionalParams: true
            },
            {
                label: 'Input Moderation',
                description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
                name: 'inputModeration',
                type: 'Moderation',
                optional: true,
                list: true
            }
        ]
        this.sessionId = fields?.sessionId
    }

    async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
        return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)
    }

    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
        const memory = nodeData.inputs?.memory as FlowiseMemory
        const moderations = nodeData.inputs?.inputModeration as Moderation[]

        if (moderations && moderations.length > 0) {
            try {
                // Use the output of the moderation chain as input for the MistralAI Function Agent
                input = await checkInputs(moderations, input)
            } catch (e) {
                await new Promise((resolve) => setTimeout(resolve, 500))
                streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId)
                return formatResponse(e.message)
            }
        }

        const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)

        const loggerHandler = new ConsoleCallbackHandler(options.logger)
        const callbacks = await additionalCallbacks(nodeData, options)

        let res: ChainValues = {}
        let sourceDocuments: ICommonObject[] = []
        let usedTools: IUsedTool[] = []

        if (options.socketIO && options.socketIOClientId) {
            const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
            res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] })
            if (res.sourceDocuments) {
                options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
                sourceDocuments = res.sourceDocuments
            }
            if (res.usedTools) {
                options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
                usedTools = res.usedTools
            }
        } else {
            res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
            if (res.sourceDocuments) {
                sourceDocuments = res.sourceDocuments
            }
            if (res.usedTools) {
                usedTools = res.usedTools
            }
        }

        await memory.addChatMessages(
            [
                {
                    text: input,
                    type: 'userMessage'
                },
                {
                    text: res?.output,
                    type: 'apiMessage'
                }
            ],
            this.sessionId
        )

        let finalRes = res?.output

        if (sourceDocuments.length || usedTools.length) {
            finalRes = { text: res?.output }
            if (sourceDocuments.length) {
                finalRes.sourceDocuments = flatten(sourceDocuments)
            }
            if (usedTools.length) {
                finalRes.usedTools = usedTools
            }
            return finalRes
        }

        return finalRes
    }
}

const prepareAgent = (
    nodeData: INodeData,
    flowObj: { sessionId?: string; chatId?: string; input?: string },
    chatHistory: IMessage[] = []
) => {
    const model = nodeData.inputs?.model as ChatOpenAI
    const memory = nodeData.inputs?.memory as FlowiseMemory
    const systemMessage = nodeData.inputs?.systemMessage as string
    let tools = nodeData.inputs?.tools
    tools = flatten(tools)
    const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history'
    const inputKey = memory.inputKey ? memory.inputKey : 'input'

    const prompt = ChatPromptTemplate.fromMessages([
        ['system', systemMessage ? systemMessage : `You are a helpful AI assistant.`],
        new MessagesPlaceholder(memoryKey),
        ['human', `{${inputKey}}`],
        new MessagesPlaceholder('agent_scratchpad')
    ])

    const llmWithTools = model.bind({
        tools: tools.map(convertToOpenAITool)
    })

    const runnableAgent = RunnableSequence.from([
        {
            [inputKey]: (i: { input: string; steps: AgentStep[] }) => i.input,
            agent_scratchpad: (i: { input: string; steps: AgentStep[] }) => formatAgentSteps(i.steps),
            [memoryKey]: async (_: { input: string; steps: AgentStep[] }) => {
                const messages = (await memory.getChatMessages(flowObj?.sessionId, true, chatHistory)) as BaseMessage[]
                return messages ?? []
            }
        },
        prompt,
        llmWithTools,
        new OpenAIToolsAgentOutputParser()
    ])

    const executor = AgentExecutor.fromAgentAndTools({
        agent: runnableAgent,
        tools,
        sessionId: flowObj?.sessionId,
        chatId: flowObj?.chatId,
        input: flowObj?.input,
        verbose: process.env.DEBUG === 'true' ? true : false
    })

    return executor
}

module.exports = { nodeClass: MistralAIFunctionAgent_Agents }
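
Note: for callers of the new node, the run() method above resolves to either a plain string or, when tools or source documents were involved, an object. A rough sketch of that result shape; the field names are taken from this diff, but the exact IUsedTool interface lives in src/Interface and is assumed here.

// Assumed shapes; for illustration only.
interface UsedTool {
    tool: string // tool.name
    toolInput: any // the raw action.toolInput passed to the tool
    toolOutput: string // the observation, with any source-document suffix stripped
}

type AgentRunResult =
    | string // plain answer when no tools or source documents were involved
    | { text: string; sourceDocuments?: unknown[]; usedTools?: UsedTool[] }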

@@ -7,7 +7,7 @@ import { ChatOpenAI, formatToOpenAIFunction } from '@langchain/openai'
 import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
 import { OpenAIFunctionsAgentOutputParser } from 'langchain/agents/openai/output_parser'
 import { getBaseClasses } from '../../../src/utils'
-import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
+import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface'
 import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
 import { AgentExecutor, formatAgentSteps } from '../../../src/agents'
 import { Moderation, checkInputs } from '../../moderation/Moderation'
@@ -97,6 +97,7 @@ class OpenAIFunctionAgent_Agents implements INode {

         let res: ChainValues = {}
         let sourceDocuments: ICommonObject[] = []
+        let usedTools: IUsedTool[] = []

         if (options.socketIO && options.socketIOClientId) {
             const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
@@ -105,11 +106,18 @@ class OpenAIFunctionAgent_Agents implements INode {
                 options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
                 sourceDocuments = res.sourceDocuments
             }
+            if (res.usedTools) {
+                options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
+                usedTools = res.usedTools
+            }
         } else {
             res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
             if (res.sourceDocuments) {
                 sourceDocuments = res.sourceDocuments
             }
+            if (res.usedTools) {
+                usedTools = res.usedTools
+            }
         }

         await memory.addChatMessages(
@@ -126,7 +134,20 @@ class OpenAIFunctionAgent_Agents implements INode {
             this.sessionId
         )

-        return sourceDocuments.length ? { text: res?.output, sourceDocuments: flatten(sourceDocuments) } : res?.output
+        let finalRes = res?.output
+
+        if (sourceDocuments.length || usedTools.length) {
+            finalRes = { text: res?.output }
+            if (sourceDocuments.length) {
+                finalRes.sourceDocuments = flatten(sourceDocuments)
+            }
+            if (usedTools.length) {
+                finalRes.usedTools = usedTools
+            }
+            return finalRes
+        }
+
+        return finalRes
     }
 }

@@ -7,7 +7,7 @@ import { Tool } from '@langchain/core/tools'
 import { ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
 import { formatLogToMessage } from 'langchain/agents/format_scratchpad/log_to_message'
 import { getBaseClasses } from '../../../src/utils'
-import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
+import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface'
 import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
 import { AgentExecutor, XMLAgentOutputParser } from '../../../src/agents'
 import { Moderation, checkInputs } from '../../moderation/Moderation'
@@ -48,6 +48,7 @@ class XMLAgent_Agents implements INode {
     baseClasses: string[]
     inputs: INodeParams[]
     sessionId?: string
+    badge?: string

     constructor(fields?: { sessionId?: string }) {
         this.label = 'XML Agent'
@@ -56,6 +57,7 @@ class XMLAgent_Agents implements INode {
         this.type = 'XMLAgent'
         this.category = 'Agents'
         this.icon = 'xmlagent.svg'
+        this.badge = 'NEW'
         this.description = `Agent that is designed for LLMs that are good for reasoning/writing XML (e.g: Anthropic Claude)`
         this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)]
         this.inputs = [
@@ -121,6 +123,7 @@ class XMLAgent_Agents implements INode {

         let res: ChainValues = {}
         let sourceDocuments: ICommonObject[] = []
+        let usedTools: IUsedTool[] = []

         if (options.socketIO && options.socketIOClientId) {
             const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
@@ -129,11 +132,18 @@ class XMLAgent_Agents implements INode {
                 options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
                 sourceDocuments = res.sourceDocuments
             }
+            if (res.usedTools) {
+                options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
+                usedTools = res.usedTools
+            }
         } else {
             res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
             if (res.sourceDocuments) {
                 sourceDocuments = res.sourceDocuments
             }
+            if (res.usedTools) {
+                usedTools = res.usedTools
+            }
         }

         await memory.addChatMessages(
@@ -150,7 +160,20 @@ class XMLAgent_Agents implements INode {
             this.sessionId
         )

-        return sourceDocuments.length ? { text: res?.output, sourceDocuments: flatten(sourceDocuments) } : res?.output
+        let finalRes = res?.output
+
+        if (sourceDocuments.length || usedTools.length) {
+            finalRes = { text: res?.output }
+            if (sourceDocuments.length) {
+                finalRes.sourceDocuments = flatten(sourceDocuments)
+            }
+            if (usedTools.length) {
+                finalRes.usedTools = usedTools
+            }
+            return finalRes
+        }
+
+        return finalRes
     }
 }

@@ -1,8 +1,36 @@
+import { ChatCompletionResponse, ToolCalls as MistralAIToolCalls } from '@mistralai/mistralai'
 import { BaseCache } from '@langchain/core/caches'
-import { ChatMistralAI, ChatMistralAIInput } from '@langchain/mistralai'
+import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
+import { NewTokenIndices } from '@langchain/core/callbacks/base'
+import { ChatGeneration, ChatGenerationChunk, ChatResult } from '@langchain/core/outputs'
+import {
+    MessageType,
+    type BaseMessage,
+    MessageContent,
+    AIMessage,
+    HumanMessage,
+    HumanMessageChunk,
+    AIMessageChunk,
+    ToolMessageChunk,
+    ChatMessageChunk
+} from '@langchain/core/messages'
+import { ChatMistralAI as LangchainChatMistralAI, ChatMistralAIInput } from '@langchain/mistralai'
 import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
 import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
+
+interface TokenUsage {
+    completionTokens?: number
+    promptTokens?: number
+    totalTokens?: number
+}
+
+type MistralAIInputMessage = {
+    role: string
+    name?: string
+    content: string | string[]
+    tool_calls?: MistralAIToolCalls[] | any[]
+}

 class ChatMistral_ChatModels implements INode {
     label: string
     name: string
@@ -135,4 +163,243 @@ class ChatMistral_ChatModels implements INode {
     }
 }

+class ChatMistralAI extends LangchainChatMistralAI {
+    async _generate(
+        messages: BaseMessage[],
+        options?: this['ParsedCallOptions'],
+        runManager?: CallbackManagerForLLMRun
+    ): Promise<ChatResult> {
+        const tokenUsage: TokenUsage = {}
+        const params = this.invocationParams(options)
+        const mistralMessages = this.convertMessagesToMistralMessages(messages)
+        const input = {
+            ...params,
+            messages: mistralMessages
+        }
+
+        // Handle streaming
+        if (this.streaming) {
+            const stream = this._streamResponseChunks(messages, options, runManager)
+            const finalChunks: Record<number, ChatGenerationChunk> = {}
+            for await (const chunk of stream) {
+                const index = (chunk.generationInfo as NewTokenIndices)?.completion ?? 0
+                if (finalChunks[index] === undefined) {
+                    finalChunks[index] = chunk
+                } else {
+                    finalChunks[index] = finalChunks[index].concat(chunk)
+                }
+            }
+            const generations = Object.entries(finalChunks)
+                .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10))
+                .map(([_, value]) => value)
+
+            return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } }
+        }
+
+        // Not streaming, so we can just call the API once.
+        const response = await this.completionWithRetry(input, false)
+
+        const { completion_tokens: completionTokens, prompt_tokens: promptTokens, total_tokens: totalTokens } = response?.usage ?? {}
+
+        if (completionTokens) {
+            tokenUsage.completionTokens = (tokenUsage.completionTokens ?? 0) + completionTokens
+        }
+
+        if (promptTokens) {
+            tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens
+        }
+
+        if (totalTokens) {
+            tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens
+        }
+
+        const generations: ChatGeneration[] = []
+        for (const part of response?.choices ?? []) {
+            if ('delta' in part) {
+                throw new Error('Delta not supported in non-streaming mode.')
+            }
+            if (!('message' in part)) {
+                throw new Error('No message found in the choice.')
+            }
+            const text = part.message?.content ?? ''
+            const generation: ChatGeneration = {
+                text,
+                message: this.mistralAIResponseToChatMessage(part)
+            }
+            if (part.finish_reason) {
+                generation.generationInfo = { finish_reason: part.finish_reason }
+            }
+            generations.push(generation)
+        }
+        return {
+            generations,
+            llmOutput: { tokenUsage }
+        }
+    }
+
+    async *_streamResponseChunks(
+        messages: BaseMessage[],
+        options?: this['ParsedCallOptions'],
+        runManager?: CallbackManagerForLLMRun
+    ): AsyncGenerator<ChatGenerationChunk> {
+        const mistralMessages = this.convertMessagesToMistralMessages(messages)
+        const params = this.invocationParams(options)
+        const input = {
+            ...params,
+            messages: mistralMessages
+        }
+
+        const streamIterable = await this.completionWithRetry(input, true)
+        for await (const data of streamIterable) {
+            const choice = data?.choices[0]
+            if (!choice || !('delta' in choice)) {
+                continue
+            }
+
+            const { delta } = choice
+            if (!delta) {
+                continue
+            }
+            const newTokenIndices = {
+                prompt: 0,
+                completion: choice.index ?? 0
+            }
+            const message = this._convertDeltaToMessageChunk(delta)
+            if (message === null) {
+                // Do not yield a chunk if the message is empty
+                continue
+            }
+            const generationChunk = new ChatGenerationChunk({
+                message,
+                text: delta.content ?? '',
+                generationInfo: newTokenIndices
+            })
+            yield generationChunk
+
+            void runManager?.handleLLMNewToken(generationChunk.text ?? '', newTokenIndices, undefined, undefined, undefined, {
+                chunk: generationChunk
+            })
+        }
+        if (options?.signal?.aborted) {
+            throw new Error('AbortError')
+        }
+    }
+
+    _convertDeltaToMessageChunk(delta: {
+        role?: string | undefined
+        content?: string | undefined
+        tool_calls?: MistralAIToolCalls[] | undefined
+    }) {
+        if (!delta.content && !delta.tool_calls) {
+            return null
+        }
+        // Our merge additional kwargs util function will throw unless there
+        // is an index key in each tool object (as seen in OpenAI's) so we
+        // need to insert it here.
+        const toolCallsWithIndex = delta.tool_calls?.length
+            ? delta.tool_calls?.map((toolCall, index) => ({
+                  ...toolCall,
+                  index
+              }))
+            : undefined
+
+        let role = 'assistant'
+        if (delta.role) {
+            role = delta.role
+        } else if (toolCallsWithIndex) {
+            role = 'tool'
+        }
+        const content = delta.content ?? ''
+        let additional_kwargs
+        if (toolCallsWithIndex) {
+            additional_kwargs = {
+                tool_calls: toolCallsWithIndex
+            }
+        } else {
+            additional_kwargs = {}
+        }
+
+        if (role === 'user') {
+            return new HumanMessageChunk({ content })
+        } else if (role === 'assistant') {
+            return new AIMessageChunk({ content, additional_kwargs })
+        } else if (role === 'tool') {
+            return new ToolMessageChunk({
+                content,
+                additional_kwargs,
+                tool_call_id: toolCallsWithIndex?.[0].id ?? ''
+            })
+        } else {
+            return new ChatMessageChunk({ content, role })
+        }
+    }
+
+    convertMessagesToMistralMessages(messages: Array<BaseMessage>): Array<MistralAIInputMessage> {
+        const getRole = (role: MessageType) => {
+            switch (role) {
+                case 'human':
+                    return 'user'
+                case 'ai':
+                    return 'assistant'
+                case 'tool':
+                    return 'tool'
+                case 'function':
+                    return 'function'
+                case 'system':
+                    return 'system'
+                default:
+                    throw new Error(`Unknown message type: ${role}`)
+            }
+        }
+
+        const getContent = (content: MessageContent): string => {
+            if (typeof content === 'string') {
+                return content
+            }
+            throw new Error(`ChatMistralAI does not support non text message content. Received: ${JSON.stringify(content, null, 2)}`)
+        }
+
+        const mistralMessages = []
+        for (const msg of messages) {
+            const msgObj: MistralAIInputMessage = {
+                role: getRole(msg._getType()),
+                content: getContent(msg.content)
+            }
+            if (getRole(msg._getType()) === 'tool') {
+                msgObj.role = 'assistant'
+                msgObj.tool_calls = msg.additional_kwargs?.tool_calls ?? []
+            } else if (getRole(msg._getType()) === 'function') {
+                msgObj.role = 'tool'
+                msgObj.name = msg.name
+            }
+
+            mistralMessages.push(msgObj)
+        }
+
+        return mistralMessages
+    }
+
+    mistralAIResponseToChatMessage(choice: ChatCompletionResponse['choices'][0]): BaseMessage {
+        const { message } = choice
+        // MistralAI SDK does not include tool_calls in the non
+        // streaming return type, so we need to extract it like this
+        // to satisfy typescript.
+        let toolCalls: MistralAIToolCalls[] = []
+        if ('tool_calls' in message) {
+            toolCalls = message.tool_calls as MistralAIToolCalls[]
+        }
+        switch (message.role) {
+            case 'assistant':
+                return new AIMessage({
+                    content: message.content ?? '',
+                    additional_kwargs: {
+                        tool_calls: toolCalls
+                    }
+                })
+            default:
+                return new HumanMessage(message.content ?? '')
+        }
+    }
+}
+
 module.exports = { nodeClass: ChatMistral_ChatModels }
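
Note: to make the role translation in convertMessagesToMistralMessages above concrete, here is roughly how a short LangChain conversation maps onto MistralAI's chat format (illustrative input and expected output only):

import { SystemMessage, HumanMessage, AIMessage } from '@langchain/core/messages'

const history = [
    new SystemMessage('You are a helpful AI assistant.'),
    new HumanMessage('What is 2 + 2?'),
    new AIMessage('4')
]

// convertMessagesToMistralMessages(history) would produce:
// [
//     { role: 'system', content: 'You are a helpful AI assistant.' },
//     { role: 'user', content: 'What is 2 + 2?' },
//     { role: 'assistant', content: '4' }
// ]
// Per the switch above, a LangChain 'tool' message is sent with role 'assistant' plus its
// tool_calls, and a 'function' message is sent as role 'tool' with its name.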

@@ -38,9 +38,10 @@
         "@langchain/community": "^0.0.39",
         "@langchain/google-genai": "^0.0.10",
         "@langchain/groq": "^0.0.2",
-        "@langchain/mistralai": "^0.0.7",
+        "@langchain/mistralai": "^0.0.11",
         "@langchain/openai": "^0.0.14",
         "@langchain/pinecone": "^0.0.3",
+        "@mistralai/mistralai": "0.1.3",
         "@notionhq/client": "^2.2.8",
         "@opensearch-project/opensearch": "^1.2.0",
         "@pinecone-database/pinecone": "^2.0.1",
@@ -70,7 +71,7 @@
         "ioredis": "^5.3.2",
         "jsdom": "^22.1.0",
         "jsonpointer": "^5.0.1",
-        "langchain": "^0.1.20",
+        "langchain": "^0.1.26",
         "langfuse": "3.3.1",
         "langfuse-langchain": "^3.3.1",
         "langsmith": "0.1.6",

@@ -20,6 +20,7 @@ import {
     StoppingMethod
 } from 'langchain/agents'
 import { formatLogToString } from 'langchain/agents/format_scratchpad/log'
+import { IUsedTool } from './Interface'

 export const SOURCE_DOCUMENTS_PREFIX = '\n\n----FLOWISE_SOURCE_DOCUMENTS----\n\n'
 type AgentFinish = {
@@ -341,11 +342,13 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
         const steps: AgentStep[] = []
         let iterations = 0
         let sourceDocuments: Array<Document> = []
+        const usedTools: IUsedTool[] = []

         const getOutput = async (finishStep: AgentFinish): Promise<AgentExecutorOutput> => {
             const { returnValues } = finishStep
             const additional = await this.agent.prepareForOutput(returnValues, steps)
             if (sourceDocuments.length) additional.sourceDocuments = flatten(sourceDocuments)
+            if (usedTools.length) additional.usedTools = usedTools

             if (this.returnIntermediateSteps) {
                 return { ...returnValues, intermediateSteps: steps, ...additional }
@@ -410,18 +413,27 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
                 * - tags?: string[]
                 * - flowConfig?: { sessionId?: string, chatId?: string, input?: string }
                 */
-                observation = tool
-                    ? await (tool as any).call(
-                          this.isXML && typeof action.toolInput === 'string' ? { input: action.toolInput } : action.toolInput,
-                          runManager?.getChild(),
-                          undefined,
-                          {
-                              sessionId: this.sessionId,
-                              chatId: this.chatId,
-                              input: this.input
-                          }
-                      )
-                    : `${action.tool} is not a valid tool, try another one.`
+                if (tool) {
+                    observation = await (tool as any).call(
+                        this.isXML && typeof action.toolInput === 'string' ? { input: action.toolInput } : action.toolInput,
+                        runManager?.getChild(),
+                        undefined,
+                        {
+                            sessionId: this.sessionId,
+                            chatId: this.chatId,
+                            input: this.input
+                        }
+                    )
+                    usedTools.push({
+                        tool: tool.name,
+                        toolInput: action.toolInput as any,
+                        toolOutput: observation.includes(SOURCE_DOCUMENTS_PREFIX)
+                            ? observation.split(SOURCE_DOCUMENTS_PREFIX)[0]
+                            : observation
+                    })
+                } else {
+                    observation = `${action.tool} is not a valid tool, try another one.`
+                }
             } catch (e) {
                 if (e instanceof ToolInputParsingException) {
                     if (this.handleParsingErrors === true) {
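
Note: the toolOutput stored above is stripped of any trailing source documents a tool may have appended via the SOURCE_DOCUMENTS_PREFIX convention defined earlier in this file. A small illustration (the tool output string is made up):

const SOURCE_DOCUMENTS_PREFIX = '\n\n----FLOWISE_SOURCE_DOCUMENTS----\n\n'

// A tool can append its retrieved documents after the marker:
const observation = 'Paris is the capital of France.' + SOURCE_DOCUMENTS_PREFIX + JSON.stringify([{ pageContent: '...', metadata: {} }])

// Only the part before the marker ends up in usedTools[n].toolOutput:
const toolOutput = observation.includes(SOURCE_DOCUMENTS_PREFIX)
    ? observation.split(SOURCE_DOCUMENTS_PREFIX)[0]
    : observation
// toolOutput === 'Paris is the capital of France.'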

@@ -905,7 +905,13 @@ export const isFlowValidForStream = (reactFlowNodes: IReactFlowNode[], endingNod
         isValidChainOrAgent = !blacklistChains.includes(endingNodeData.name)
     } else if (endingNodeData.category === 'Agents') {
         // Agent that are available to stream
-        const whitelistAgents = ['openAIFunctionAgent', 'csvAgent', 'airtableAgent', 'conversationalRetrievalAgent']
+        const whitelistAgents = [
+            'openAIFunctionAgent',
+            'mistralAIFunctionAgent',
+            'csvAgent',
+            'airtableAgent',
+            'conversationalRetrievalAgent'
+        ]
         isValidChainOrAgent = whitelistAgents.includes(endingNodeData.name)
     } else if (endingNodeData.category === 'Engine') {
         const whitelistEngine = ['contextChatEngine', 'simpleChatEngine', 'queryEngine', 'subQuestionQueryEngine']

@@ -23,7 +23,7 @@ import {
     Typography
 } from '@mui/material'
 import { useTheme } from '@mui/material/styles'
-import { IconCircleDot, IconDownload, IconSend, IconMicrophone, IconPhotoPlus, IconTrash, IconX } from '@tabler/icons'
+import { IconCircleDot, IconDownload, IconSend, IconMicrophone, IconPhotoPlus, IconTrash, IconX, IconTool } from '@tabler/icons'
 import robotPNG from '@/assets/images/robot.png'
 import userPNG from '@/assets/images/account.png'
 import audioUploadSVG from '@/assets/images/wave-sound.jpg'
@@ -340,6 +340,15 @@ export const ChatMessage = ({ open, chatflowid, isDialog, previews, setPreviews
         })
     }

+    const updateLastMessageUsedTools = (usedTools) => {
+        setMessages((prevMessages) => {
+            let allMessages = [...cloneDeep(prevMessages)]
+            if (allMessages[allMessages.length - 1].type === 'userMessage') return allMessages
+            allMessages[allMessages.length - 1].usedTools = usedTools
+            return allMessages
+        })
+    }
+
     // Handle errors
     const handleError = (message = 'Oops! There seems to be an error. Please try again.') => {
         message = message.replace(`Unable to parse JSON response from chat agent.\n\n`, '')
@@ -596,6 +605,8 @@ export const ChatMessage = ({ open, chatflowid, isDialog, previews, setPreviews

             socket.on('sourceDocuments', updateLastMessageSourceDocuments)

+            socket.on('usedTools', updateLastMessageUsedTools)
+
             socket.on('token', updateLastMessage)
         }

@@ -770,6 +781,7 @@ export const ChatMessage = ({ open, chatflowid, isDialog, previews, setPreviews
                                             sx={{ mr: 1, mt: 1 }}
                                             variant='outlined'
                                             clickable
+                                            icon={<IconTool size={15} />}
                                             onClick={() => onSourceDialogClick(tool, 'Used Tools')}
                                         />
                                     )

pnpm-lock.yaml: file diff suppressed because it is too large (59,734 lines).

@@ -1,2 +1,2 @@
 packages:
-    - 'packages/*'
+    - 'packages/*'