import { ChatCompletionResponse, ToolCalls as MistralAIToolCalls } from '@mistralai/mistralai'
import { BaseCache } from '@langchain/core/caches'
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
import { NewTokenIndices } from '@langchain/core/callbacks/base'
import { ChatGeneration, ChatGenerationChunk, ChatResult } from '@langchain/core/outputs'
import {
    MessageType,
    type BaseMessage,
    MessageContent,
    AIMessage,
    HumanMessage,
    HumanMessageChunk,
    AIMessageChunk,
    ToolMessageChunk,
    ChatMessageChunk
} from '@langchain/core/messages'
import { ChatMistralAI as LangchainChatMistralAI, ChatMistralAIInput } from '@langchain/mistralai'
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'

interface TokenUsage {
    completionTokens?: number
    promptTokens?: number
    totalTokens?: number
}

type MistralAIInputMessage = {
    role: string
    name?: string
    content: string | string[]
    tool_calls?: MistralAIToolCalls[] | any[]
}

class ChatMistral_ChatModels implements INode {
    label: string
    name: string
    version: number
    type: string
    icon: string
    category: string
    description: string
    baseClasses: string[]
    credential: INodeParams
    inputs: INodeParams[]

    constructor() {
        this.label = 'ChatMistralAI'
        this.name = 'chatMistralAI'
        this.version = 2.0
        this.type = 'ChatMistralAI'
        this.icon = 'MistralAI.svg'
        this.category = 'Chat Models'
        this.description = 'Wrapper around Mistral large language models that use the Chat endpoint'
        this.baseClasses = [this.type, ...getBaseClasses(ChatMistralAI)]
        this.credential = {
            label: 'Connect Credential',
            name: 'credential',
            type: 'credential',
            credentialNames: ['mistralAIApi']
        }
        this.inputs = [
            {
                label: 'Cache',
                name: 'cache',
                type: 'BaseCache',
                optional: true
            },
            {
                label: 'Model Name',
                name: 'modelName',
                type: 'string',
                description: 'Refer to Model Selection for more available models',
                default: 'mistral-tiny'
            },
            {
                label: 'Temperature',
                name: 'temperature',
                type: 'number',
                description:
                    'What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.',
                step: 0.1,
                default: 0.9,
                optional: true
            },
            {
                label: 'Max Output Tokens',
                name: 'maxOutputTokens',
                type: 'number',
                description: 'The maximum number of tokens to generate in the completion.',
                step: 1,
                optional: true,
                additionalParams: true
            },
            {
                label: 'Top Probability',
                name: 'topP',
                type: 'number',
                description:
                    'Nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.',
                step: 0.1,
                optional: true,
                additionalParams: true
            },
            {
                label: 'Random Seed',
                name: 'randomSeed',
                type: 'number',
                description: 'The seed to use for random sampling. If set, different calls will generate deterministic results.',
                step: 1,
                optional: true,
                additionalParams: true
            },
            {
                label: 'Safe Mode',
                name: 'safeMode',
                type: 'boolean',
                description: 'Whether to inject a safety prompt before all conversations.',
                optional: true,
                additionalParams: true
            },
            {
                label: 'Override Endpoint',
                name: 'overrideEndpoint',
                type: 'string',
                optional: true,
                additionalParams: true
            }
        ]
    }

    async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> {
        const credentialData = await getCredentialData(nodeData.credential ?? '', options)
        const apiKey = getCredentialParam('mistralAIAPIKey', credentialData, nodeData)

        const temperature = nodeData.inputs?.temperature as string
        const modelName = nodeData.inputs?.modelName as string
        const maxOutputTokens = nodeData.inputs?.maxOutputTokens as string
        const topP = nodeData.inputs?.topP as string
        const safeMode = nodeData.inputs?.safeMode as boolean
        const randomSeed = nodeData.inputs?.randomSeed as string
        const overrideEndpoint = nodeData.inputs?.overrideEndpoint as string
        const streaming = nodeData.inputs?.streaming as boolean
        const cache = nodeData.inputs?.cache as BaseCache

        const obj: ChatMistralAIInput = {
            apiKey: apiKey,
            modelName: modelName,
            streaming: streaming ?? true
        }

        if (maxOutputTokens) obj.maxTokens = parseInt(maxOutputTokens, 10)
        if (topP) obj.topP = parseFloat(topP)
        if (cache) obj.cache = cache
        if (temperature) obj.temperature = parseFloat(temperature)
        if (randomSeed) obj.randomSeed = parseInt(randomSeed, 10)
        if (safeMode) obj.safeMode = safeMode
        if (overrideEndpoint) obj.endpoint = overrideEndpoint

        const model = new ChatMistralAI(obj)
        return model
    }
}

// Extends the LangChain ChatMistralAI wrapper to convert messages into Mistral's
// expected format and to surface tool calls and token usage in the results.
class ChatMistralAI extends LangchainChatMistralAI {
    async _generate(
        messages: BaseMessage[],
        options?: this['ParsedCallOptions'],
        runManager?: CallbackManagerForLLMRun
    ): Promise<ChatResult> {
        const tokenUsage: TokenUsage = {}
        const params = this.invocationParams(options)
        const mistralMessages = this.convertMessagesToMistralMessages(messages)
        const input = {
            ...params,
            messages: mistralMessages
        }

        // Handle streaming
        if (this.streaming) {
            const stream = this._streamResponseChunks(messages, options, runManager)
            const finalChunks: Record<number, ChatGenerationChunk> = {}
            for await (const chunk of stream) {
                const index = (chunk.generationInfo as NewTokenIndices)?.completion ?? 0
                if (finalChunks[index] === undefined) {
                    finalChunks[index] = chunk
                } else {
                    finalChunks[index] = finalChunks[index].concat(chunk)
                }
            }
            const generations = Object.entries(finalChunks)
                .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10))
                .map(([_, value]) => value)

            return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } }
        }

        // Not streaming, so we can just call the API once.
        const response = await this.completionWithRetry(input, false)

        const {
            completion_tokens: completionTokens,
            prompt_tokens: promptTokens,
            total_tokens: totalTokens
        } = response?.usage ?? {}

        if (completionTokens) {
            tokenUsage.completionTokens = (tokenUsage.completionTokens ?? 0) + completionTokens
        }

        if (promptTokens) {
            tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens
        }

        if (totalTokens) {
            tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens
        }

        const generations: ChatGeneration[] = []
        for (const part of response?.choices ?? []) {
            if ('delta' in part) {
                throw new Error('Delta not supported in non-streaming mode.')
            }
            if (!('message' in part)) {
                throw new Error('No message found in the choice.')
            }
            const text = part.message?.content ?? ''
            const generation: ChatGeneration = {
                text,
                message: this.mistralAIResponseToChatMessage(part)
            }
            if (part.finish_reason) {
                generation.generationInfo = { finish_reason: part.finish_reason }
            }
            generations.push(generation)
        }
        return { generations, llmOutput: { tokenUsage } }
    }

    async *_streamResponseChunks(
        messages: BaseMessage[],
        options?: this['ParsedCallOptions'],
        runManager?: CallbackManagerForLLMRun
    ): AsyncGenerator<ChatGenerationChunk> {
        const mistralMessages = this.convertMessagesToMistralMessages(messages)
        const params = this.invocationParams(options)
        const input = {
            ...params,
            messages: mistralMessages
        }

        const streamIterable = await this.completionWithRetry(input, true)
        for await (const data of streamIterable) {
            const choice = data?.choices[0]
            if (!choice || !('delta' in choice)) {
                continue
            }

            const { delta } = choice
            if (!delta) {
                continue
            }
            const newTokenIndices = {
                prompt: 0,
                completion: choice.index ?? 0
            }
            const message = this._convertDeltaToMessageChunk(delta)
            if (message === null) {
                // Do not yield a chunk if the message is empty
                continue
            }
            const generationChunk = new ChatGenerationChunk({
                message,
                text: delta.content ?? '',
                generationInfo: newTokenIndices
            })
            yield generationChunk
            void runManager?.handleLLMNewToken(generationChunk.text ?? '', newTokenIndices, undefined, undefined, undefined, {
                chunk: generationChunk
            })
        }
        if (options?.signal?.aborted) {
            throw new Error('AbortError')
        }
    }

    _convertDeltaToMessageChunk(delta: {
        role?: string | undefined
        content?: string | undefined
        tool_calls?: MistralAIToolCalls[] | undefined
    }) {
        if (!delta.content && !delta.tool_calls) {
            return null
        }
        // Our merge additional kwargs util function will throw unless there
        // is an index key in each tool object (as seen in OpenAI's) so we
        // need to insert it here.
        const toolCallsWithIndex = delta.tool_calls?.length
            ? delta.tool_calls?.map((toolCall, index) => ({ ...toolCall, index }))
            : undefined

        let role = 'assistant'
        if (delta.role) {
            role = delta.role
        } else if (toolCallsWithIndex) {
            role = 'tool'
        }
        const content = delta.content ?? ''
        let additional_kwargs
        if (toolCallsWithIndex) {
            additional_kwargs = {
                tool_calls: toolCallsWithIndex
            }
        } else {
            additional_kwargs = {}
        }

        if (role === 'user') {
            return new HumanMessageChunk({ content })
        } else if (role === 'assistant') {
            return new AIMessageChunk({ content, additional_kwargs })
        } else if (role === 'tool') {
            return new ToolMessageChunk({
                content,
                additional_kwargs,
                tool_call_id: toolCallsWithIndex?.[0].id ?? ''
            })
        } else {
            return new ChatMessageChunk({ content, role })
        }
    }

    convertMessagesToMistralMessages(messages: Array<BaseMessage>): Array<MistralAIInputMessage> {
        const getRole = (role: MessageType) => {
            switch (role) {
                case 'human':
                    return 'user'
                case 'ai':
                    return 'assistant'
                case 'tool':
                    return 'tool'
                case 'function':
                    return 'function'
                case 'system':
                    return 'system'
                default:
                    throw new Error(`Unknown message type: ${role}`)
            }
        }

        const getContent = (content: MessageContent): string => {
            if (typeof content === 'string') {
                return content
            }
            throw new Error(`ChatMistralAI does not support non text message content. Received: ${JSON.stringify(content, null, 2)}`)
        }

        const mistralMessages: MistralAIInputMessage[] = []
        for (const msg of messages) {
            const msgObj: MistralAIInputMessage = {
                role: getRole(msg._getType()),
                content: getContent(msg.content)
            }
            if (getRole(msg._getType()) === 'tool') {
                msgObj.role = 'assistant'
                msgObj.tool_calls = msg.additional_kwargs?.tool_calls ?? []
            } else if (getRole(msg._getType()) === 'function') {
                msgObj.role = 'tool'
                msgObj.name = msg.name
            }
            mistralMessages.push(msgObj)
        }

        return mistralMessages
    }

    mistralAIResponseToChatMessage(choice: ChatCompletionResponse['choices'][0]): BaseMessage {
        const { message } = choice
        // MistralAI SDK does not include tool_calls in the non
        // streaming return type, so we need to extract it like this
        // to satisfy typescript.
        let toolCalls: MistralAIToolCalls[] = []
        if ('tool_calls' in message) {
            toolCalls = message.tool_calls as MistralAIToolCalls[]
        }
        switch (message.role) {
            case 'assistant':
                return new AIMessage({
                    content: message.content ?? '',
                    additional_kwargs: {
                        tool_calls: toolCalls
                    }
                })
            default:
                return new HumanMessage(message.content ?? '')
        }
    }
}

module.exports = { nodeClass: ChatMistral_ChatModels }