Added NeMo Guardrails Chat component (#3331)
* added: nemoguardrails chat component * fix: removed options argument from init fix: generation text has been updated with content string * fix: lint error fixed * fix: error when get content * fix: lint errors * updated: added handleLLMNewToken for ChatNemoGuardrails --------- Co-authored-by: patrick <patrick.alves@br.experian.com>
This commit is contained in:
parent
d1adc4fb1c
commit
5117948ccf
|
|
@ -0,0 +1,121 @@
|
|||
import { BaseChatModel, type BaseChatModelParams } from '@langchain/core/language_models/chat_models'
|
||||
import { AIMessageChunk, BaseMessage } from '@langchain/core/messages'
|
||||
import { BaseChatModelCallOptions } from '@langchain/core/language_models/chat_models'
|
||||
import { NemoClient } from './NemoClient'
|
||||
import { CallbackManager, CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
|
||||
import { ChatResult } from '@langchain/core/outputs'
|
||||
import { FailedAttemptHandler } from '@langchain/core/utils/async_caller'
|
||||
import { getBaseClasses, INode, INodeData, INodeParams } from '../../../src'
|
||||
|
||||
export interface ChatNemoGuardrailsCallOptions extends BaseChatModelCallOptions {
|
||||
/**
|
||||
* An array of strings to stop on.
|
||||
*/
|
||||
stop?: string[]
|
||||
}
|
||||
|
||||
export interface ChatNemoGuardrailsInput extends BaseChatModelParams {
|
||||
configurationId?: string
|
||||
/**
|
||||
* The host URL of the Nemo server.
|
||||
* @default "http://localhost:8000"
|
||||
*/
|
||||
baseUrl?: string
|
||||
}
|
||||
|
||||
class ChatNemoGuardrailsModel extends BaseChatModel<ChatNemoGuardrailsCallOptions, AIMessageChunk> implements ChatNemoGuardrailsInput {
|
||||
configurationId: string
|
||||
id: string
|
||||
baseUrl: string
|
||||
callbackManager?: CallbackManager | undefined
|
||||
maxConcurrency?: number | undefined
|
||||
maxRetries?: number | undefined
|
||||
onFailedAttempt?: FailedAttemptHandler | undefined
|
||||
client: NemoClient
|
||||
|
||||
_llmType(): string {
|
||||
return 'nemo-guardrails'
|
||||
}
|
||||
|
||||
_generate(messages: BaseMessage[], options: this['ParsedCallOptions'], runManager?: CallbackManagerForLLMRun): Promise<ChatResult> {
|
||||
const generate = async (messages: BaseMessage[], client: NemoClient): Promise<ChatResult> => {
|
||||
const chatMessages = await client.chat(messages)
|
||||
const generations = chatMessages.map((message) => {
|
||||
return {
|
||||
text: message.content?.toString() ?? '',
|
||||
message
|
||||
}
|
||||
})
|
||||
|
||||
await runManager?.handleLLMNewToken(generations.length ? generations[0].text : '')
|
||||
|
||||
return {
|
||||
generations
|
||||
}
|
||||
}
|
||||
return generate(messages, this.client)
|
||||
}
|
||||
|
||||
constructor({ id, fields }: { id: string; fields: Partial<ChatNemoGuardrailsInput> & BaseChatModelParams }) {
|
||||
super(fields)
|
||||
this.id = id
|
||||
this.configurationId = fields.configurationId ?? ''
|
||||
this.baseUrl = fields.baseUrl ?? ''
|
||||
this.callbackManager = fields.callbackManager
|
||||
this.maxConcurrency = fields.maxConcurrency
|
||||
this.maxRetries = fields.maxRetries
|
||||
this.onFailedAttempt = fields.onFailedAttempt
|
||||
this.client = new NemoClient(this.baseUrl, this.configurationId)
|
||||
}
|
||||
}
|
||||
|
||||
class ChatNemoGuardrailsChatModel implements INode {
|
||||
label: string
|
||||
name: string
|
||||
version: number
|
||||
type: string
|
||||
icon: string
|
||||
category: string
|
||||
description: string
|
||||
baseClasses: string[]
|
||||
credential: INodeParams
|
||||
inputs: INodeParams[]
|
||||
|
||||
constructor() {
|
||||
this.label = 'Chat Nemo Guardrails'
|
||||
this.name = 'chatNemoGuardrails'
|
||||
this.version = 1.0
|
||||
this.type = 'ChatNemoGuardrails'
|
||||
this.icon = 'nemo.svg'
|
||||
this.category = 'Chat Models'
|
||||
this.description = 'Access models through the Nemo Guardrails API'
|
||||
this.baseClasses = [this.type, ...getBaseClasses(ChatNemoGuardrailsModel)]
|
||||
this.inputs = [
|
||||
{
|
||||
label: 'Configuration ID',
|
||||
name: 'configurationId',
|
||||
type: 'string',
|
||||
optional: false
|
||||
},
|
||||
{
|
||||
label: 'Base URL',
|
||||
name: 'baseUrl',
|
||||
type: 'string',
|
||||
optional: false
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
async init(nodeData: INodeData): Promise<any> {
|
||||
const configurationId = nodeData.inputs?.configurationId
|
||||
const baseUrl = nodeData.inputs?.baseUrl
|
||||
const obj: Partial<ChatNemoGuardrailsInput> = {
|
||||
configurationId: configurationId,
|
||||
baseUrl: baseUrl
|
||||
}
|
||||
const model = new ChatNemoGuardrailsModel({ id: nodeData.id, fields: obj })
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { nodeClass: ChatNemoGuardrailsChatModel }
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
import { AIMessage, BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages'
|
||||
|
||||
export interface Config {
|
||||
baseUrl: string
|
||||
configurationId: string
|
||||
}
|
||||
|
||||
export class ClientConfig implements Config {
|
||||
baseUrl: string
|
||||
configurationId: string
|
||||
|
||||
constructor(baseUrl: string, configurationId: string) {
|
||||
this.baseUrl = baseUrl
|
||||
this.configurationId = configurationId
|
||||
}
|
||||
}
|
||||
|
||||
export class NemoClient {
|
||||
private readonly config: Config
|
||||
|
||||
constructor(baseUrl: string, configurationId: string) {
|
||||
this.config = new ClientConfig(baseUrl, configurationId)
|
||||
}
|
||||
|
||||
getRoleFromMessage(message: BaseMessage): string {
|
||||
if (message instanceof HumanMessage || message instanceof SystemMessage) {
|
||||
return 'user'
|
||||
}
|
||||
|
||||
//AIMessage, ToolMessage, FunctionMessage
|
||||
return 'assistant'
|
||||
}
|
||||
|
||||
getContentFromMessage(message: BaseMessage): string {
|
||||
return message.content.toString()
|
||||
}
|
||||
|
||||
buildBody(messages: BaseMessage[], configurationId: string): any {
|
||||
const bodyMessages = messages.map((message) => {
|
||||
return {
|
||||
role: this.getRoleFromMessage(message),
|
||||
content: this.getContentFromMessage(message)
|
||||
}
|
||||
})
|
||||
|
||||
const body = {
|
||||
config_id: configurationId,
|
||||
messages: bodyMessages
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
async chat(messages: BaseMessage[]): Promise<AIMessage[]> {
|
||||
const headers = new Headers()
|
||||
headers.append('Content-Type', 'application/json')
|
||||
|
||||
const body = this.buildBody(messages, this.config.configurationId)
|
||||
|
||||
const requestOptions = {
|
||||
method: 'POST',
|
||||
body: JSON.stringify(body),
|
||||
headers: headers
|
||||
}
|
||||
|
||||
return await fetch(`${this.config.baseUrl}/v1/chat/completions`, requestOptions)
|
||||
.then((response) => response.json())
|
||||
.then((body) => body.messages.map((message: any) => new AIMessage(message.content)))
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 0 1024 1024" xmlns="http://www.w3.org/2000/svg">
|
||||
<circle cx="512.25" cy="512.25" r="512" style="fill:#76b900"/>
|
||||
<path d="M430.58 436.9v-33.66c3.26-.23 6.56-.41 9.92-.51 92-2.89 152.41 79.09 152.41 79.09S527.7 572.4 457.78 572.4a84.86 84.86 0 0 1-27.2-4.35V466c35.83 4.33 43 20.15 64.57 56.06l47.91-40.4s-35-45.87-93.91-45.87a170.07 170.07 0 0 0-18.56 1.1m0-111.17V376c3.3-.27 6.61-.47 9.92-.59 128-4.31 211.37 105 211.37 105s-95.79 116.42-195.57 116.42a146.92 146.92 0 0 1-25.74-2.27v31.08A169.64 169.64 0 0 0 452 627c92.85 0 160-47.42 225-103.54 10.77 8.64 54.9 29.63 64 38.83-61.83 51.76-205.91 93.48-287.59 93.48-7.87 0-15.44-.47-22.86-1.19v43.67h352.93V325.73zm0 242.31v26.52C344.69 579.26 320.85 490 320.85 490s41.24-45.69 109.73-53.08V466h-.14c-35.93-4.31-64 29.25-64 29.25s15.74 56.53 64.16 72.8M278 486.11s50.9-75.11 152.54-82.87V376C318 385 220.52 480.36 220.52 480.36S275.73 640 430.58 654.6v-29C316.95 611.34 278 486.11 278 486.11" style="fill:#fff"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.1 KiB |
|
|
@ -0,0 +1,18 @@
|
|||
Parameters:
|
||||
|
||||
config_id
|
||||
baseUrl
|
||||
|
||||
```
|
||||
/v1/chat/completions
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"config_id": "bedrock",
|
||||
"messages": [{
|
||||
"role":"user",
|
||||
"content":"Hello! What can you do for me?"
|
||||
}]
|
||||
}
|
||||
```
|
||||
Loading…
Reference in New Issue