diff --git a/packages/components/nodes/chatmodels/ChatNemoGuardrails/ChatNemoGuardrails.ts b/packages/components/nodes/chatmodels/ChatNemoGuardrails/ChatNemoGuardrails.ts new file mode 100644 index 000000000..8da1a7fbb --- /dev/null +++ b/packages/components/nodes/chatmodels/ChatNemoGuardrails/ChatNemoGuardrails.ts @@ -0,0 +1,121 @@ +import { BaseChatModel, type BaseChatModelParams } from '@langchain/core/language_models/chat_models' +import { AIMessageChunk, BaseMessage } from '@langchain/core/messages' +import { BaseChatModelCallOptions } from '@langchain/core/language_models/chat_models' +import { NemoClient } from './NemoClient' +import { CallbackManager, CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager' +import { ChatResult } from '@langchain/core/outputs' +import { FailedAttemptHandler } from '@langchain/core/utils/async_caller' +import { getBaseClasses, INode, INodeData, INodeParams } from '../../../src' + +export interface ChatNemoGuardrailsCallOptions extends BaseChatModelCallOptions { + /** + * An array of strings to stop on. + */ + stop?: string[] +} + +export interface ChatNemoGuardrailsInput extends BaseChatModelParams { + configurationId?: string + /** + * The host URL of the Nemo server. + * @default "http://localhost:8000" + */ + baseUrl?: string +} + +class ChatNemoGuardrailsModel extends BaseChatModel implements ChatNemoGuardrailsInput { + configurationId: string + id: string + baseUrl: string + callbackManager?: CallbackManager | undefined + maxConcurrency?: number | undefined + maxRetries?: number | undefined + onFailedAttempt?: FailedAttemptHandler | undefined + client: NemoClient + + _llmType(): string { + return 'nemo-guardrails' + } + + _generate(messages: BaseMessage[], options: this['ParsedCallOptions'], runManager?: CallbackManagerForLLMRun): Promise { + const generate = async (messages: BaseMessage[], client: NemoClient): Promise => { + const chatMessages = await client.chat(messages) + const generations = chatMessages.map((message) => { + return { + text: message.content?.toString() ?? '', + message + } + }) + + await runManager?.handleLLMNewToken(generations.length ? generations[0].text : '') + + return { + generations + } + } + return generate(messages, this.client) + } + + constructor({ id, fields }: { id: string; fields: Partial & BaseChatModelParams }) { + super(fields) + this.id = id + this.configurationId = fields.configurationId ?? '' + this.baseUrl = fields.baseUrl ?? '' + this.callbackManager = fields.callbackManager + this.maxConcurrency = fields.maxConcurrency + this.maxRetries = fields.maxRetries + this.onFailedAttempt = fields.onFailedAttempt + this.client = new NemoClient(this.baseUrl, this.configurationId) + } +} + +class ChatNemoGuardrailsChatModel implements INode { + label: string + name: string + version: number + type: string + icon: string + category: string + description: string + baseClasses: string[] + credential: INodeParams + inputs: INodeParams[] + + constructor() { + this.label = 'Chat Nemo Guardrails' + this.name = 'chatNemoGuardrails' + this.version = 1.0 + this.type = 'ChatNemoGuardrails' + this.icon = 'nemo.svg' + this.category = 'Chat Models' + this.description = 'Access models through the Nemo Guardrails API' + this.baseClasses = [this.type, ...getBaseClasses(ChatNemoGuardrailsModel)] + this.inputs = [ + { + label: 'Configuration ID', + name: 'configurationId', + type: 'string', + optional: false + }, + { + label: 'Base URL', + name: 'baseUrl', + type: 'string', + optional: false + } + ] + } + + async init(nodeData: INodeData): Promise { + const configurationId = nodeData.inputs?.configurationId + const baseUrl = nodeData.inputs?.baseUrl + const obj: Partial = { + configurationId: configurationId, + baseUrl: baseUrl + } + const model = new ChatNemoGuardrailsModel({ id: nodeData.id, fields: obj }) + return model + } +} + +module.exports = { nodeClass: ChatNemoGuardrailsChatModel } diff --git a/packages/components/nodes/chatmodels/ChatNemoGuardrails/NemoClient.ts b/packages/components/nodes/chatmodels/ChatNemoGuardrails/NemoClient.ts new file mode 100644 index 000000000..da14f9b40 --- /dev/null +++ b/packages/components/nodes/chatmodels/ChatNemoGuardrails/NemoClient.ts @@ -0,0 +1,70 @@ +import { AIMessage, BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages' + +export interface Config { + baseUrl: string + configurationId: string +} + +export class ClientConfig implements Config { + baseUrl: string + configurationId: string + + constructor(baseUrl: string, configurationId: string) { + this.baseUrl = baseUrl + this.configurationId = configurationId + } +} + +export class NemoClient { + private readonly config: Config + + constructor(baseUrl: string, configurationId: string) { + this.config = new ClientConfig(baseUrl, configurationId) + } + + getRoleFromMessage(message: BaseMessage): string { + if (message instanceof HumanMessage || message instanceof SystemMessage) { + return 'user' + } + + //AIMessage, ToolMessage, FunctionMessage + return 'assistant' + } + + getContentFromMessage(message: BaseMessage): string { + return message.content.toString() + } + + buildBody(messages: BaseMessage[], configurationId: string): any { + const bodyMessages = messages.map((message) => { + return { + role: this.getRoleFromMessage(message), + content: this.getContentFromMessage(message) + } + }) + + const body = { + config_id: configurationId, + messages: bodyMessages + } + + return body + } + + async chat(messages: BaseMessage[]): Promise { + const headers = new Headers() + headers.append('Content-Type', 'application/json') + + const body = this.buildBody(messages, this.config.configurationId) + + const requestOptions = { + method: 'POST', + body: JSON.stringify(body), + headers: headers + } + + return await fetch(`${this.config.baseUrl}/v1/chat/completions`, requestOptions) + .then((response) => response.json()) + .then((body) => body.messages.map((message: any) => new AIMessage(message.content))) + } +} diff --git a/packages/components/nodes/chatmodels/ChatNemoGuardrails/nemo.svg b/packages/components/nodes/chatmodels/ChatNemoGuardrails/nemo.svg new file mode 100644 index 000000000..76a39e2e6 --- /dev/null +++ b/packages/components/nodes/chatmodels/ChatNemoGuardrails/nemo.svg @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/packages/components/nodes/chatmodels/ChatNemoGuardrails/readme.md b/packages/components/nodes/chatmodels/ChatNemoGuardrails/readme.md new file mode 100644 index 000000000..44d1e5d7f --- /dev/null +++ b/packages/components/nodes/chatmodels/ChatNemoGuardrails/readme.md @@ -0,0 +1,18 @@ +Parameters: + +config_id +baseUrl + +``` +/v1/chat/completions +``` + +```json +{ + "config_id": "bedrock", + "messages": [{ + "role":"user", + "content":"Hello! What can you do for me?" + }] +} +``` \ No newline at end of file