import { BaseChatModel, type BaseChatModelParams } from '@langchain/core/language_models/chat_models'
import { AIMessageChunk, BaseMessage } from '@langchain/core/messages'
import { BaseChatModelCallOptions } from '@langchain/core/language_models/chat_models'
import { NemoClient } from './NemoClient'
import { CallbackManager, CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
import { ChatResult } from '@langchain/core/outputs'
import { FailedAttemptHandler } from '@langchain/core/utils/async_caller'
import { getBaseClasses, INode, INodeData, INodeParams } from '../../../src'

/**
 * Base URL used when the caller does not supply one. Mirrors the `@default`
 * documented on {@link ChatNemoGuardrailsInput.baseUrl}.
 */
const DEFAULT_BASE_URL = 'http://localhost:8000'

/** Per-call options accepted by the Nemo Guardrails chat model. */
export interface ChatNemoGuardrailsCallOptions extends BaseChatModelCallOptions {
    /**
     * An array of strings to stop on.
     */
    stop?: string[]
}

/** Construction-time settings for the Nemo Guardrails chat model. */
export interface ChatNemoGuardrailsInput extends BaseChatModelParams {
    /** Identifier of the guardrails configuration handed to the client. */
    configurationId?: string
    /**
     * The host URL of the Nemo server.
     * @default "http://localhost:8000"
     */
    baseUrl?: string
}

/**
 * LangChain chat model that delegates every generation request to a
 * {@link NemoClient} built from the configured base URL and configuration id.
 */
class ChatNemoGuardrailsModel extends BaseChatModel<ChatNemoGuardrailsCallOptions> implements ChatNemoGuardrailsInput {
    configurationId: string
    id: string
    baseUrl: string
    callbackManager?: CallbackManager | undefined
    maxConcurrency?: number | undefined
    maxRetries?: number | undefined
    onFailedAttempt?: FailedAttemptHandler | undefined
    client: NemoClient

    /**
     * @param id - Identifier stored on the model instance (the node id when created by the node class below).
     * @param fields - Model settings; a missing `configurationId` becomes `''` and a missing
     * `baseUrl` falls back to the documented default of `http://localhost:8000`.
     */
    constructor({ id, fields }: { id: string; fields: Partial<ChatNemoGuardrailsInput> & BaseChatModelParams }) {
        super(fields)
        this.id = id
        this.configurationId = fields.configurationId ?? ''
        // Use the documented default instead of an empty string when no URL is given.
        this.baseUrl = fields.baseUrl ?? DEFAULT_BASE_URL
        this.callbackManager = fields.callbackManager
        this.maxConcurrency = fields.maxConcurrency
        this.maxRetries = fields.maxRetries
        this.onFailedAttempt = fields.onFailedAttempt
        this.client = new NemoClient(this.baseUrl, this.configurationId)
    }

    /** Identifier reported to LangChain for this model type. */
    _llmType(): string {
        return 'nemo-guardrails'
    }

    /**
     * Sends the conversation to the Nemo client and wraps each returned message
     * as a chat generation.
     *
     * @param messages - Conversation passed unchanged to `client.chat`.
     * @param options - Parsed call options; currently not forwarded to the client.
     * @param runManager - Optional callback manager; it is notified once with the
     * text of the first generation (or `''` when there are none).
     * @returns The generations built from the client's reply.
     */
    async _generate(
        messages: BaseMessage[],
        options: this['ParsedCallOptions'],
        runManager?: CallbackManagerForLLMRun
    ): Promise<ChatResult> {
        const replies = await this.client.chat(messages)
        const generations = replies.map((message) => ({
            text: message.content?.toString() ?? '',
            message
        }))

        // The reply arrives in one piece, so it is reported as a single "token".
        const firstText = generations.length ? generations[0].text : ''
        await runManager?.handleLLMNewToken(firstText)

        return { generations }
    }
}

/** Node definition that exposes {@link ChatNemoGuardrailsModel} with its two string inputs. */
class ChatNemoGuardrailsChatModel implements INode {
    label: string
    name: string
    version: number
    type: string
    icon: string
    category: string
    description: string
    baseClasses: string[]
    credential: INodeParams
    inputs: INodeParams[]

    constructor() {
        this.label = 'Chat Nemo Guardrails'
        this.name = 'chatNemoGuardrails'
        this.version = 1.0
        this.type = 'ChatNemoGuardrails'
        this.icon = 'nemo.svg'
        this.category = 'Chat Models'
        this.description = 'Access models through the Nemo Guardrails API'
        this.baseClasses = [this.type, ...getBaseClasses(ChatNemoGuardrailsModel)]
        this.inputs = [
            {
                label: 'Configuration ID',
                name: 'configurationId',
                type: 'string',
                optional: false
            },
            {
                label: 'Base URL',
                name: 'baseUrl',
                type: 'string',
                optional: false
            }
        ]
    }

    /**
     * Builds the chat model from the node's `configurationId` and `baseUrl` inputs.
     *
     * @param nodeData - Node instance data; its `id` becomes the model id.
     * @returns A configured {@link ChatNemoGuardrailsModel}.
     */
    async init(nodeData: INodeData): Promise<ChatNemoGuardrailsModel> {
        const fields: Partial<ChatNemoGuardrailsInput> = {
            configurationId: nodeData.inputs?.configurationId,
            baseUrl: nodeData.inputs?.baseUrl
        }
        return new ChatNemoGuardrailsModel({ id: nodeData.id, fields })
    }
}

module.exports = { nodeClass: ChatNemoGuardrailsChatModel }