From 54ff43e8f16e1e65820b83100cd4afdde03dfe2f Mon Sep 17 00:00:00 2001
From: Henry Heng
Date: Tue, 16 Jul 2024 21:42:24 +0100
Subject: [PATCH] Bugfix/HF custom endpoint (#2811)

include fix for hf custom endpoint
---
 .../ChatHuggingFace/ChatHuggingFace.ts        | 19 +++-
 .../nodes/chatmodels/ChatHuggingFace/core.ts  | 91 ++++++++++++-------
 2 files changed, 76 insertions(+), 34 deletions(-)

diff --git a/packages/components/nodes/chatmodels/ChatHuggingFace/ChatHuggingFace.ts b/packages/components/nodes/chatmodels/ChatHuggingFace/ChatHuggingFace.ts
index 2729d0157..29d1b74e5 100644
--- a/packages/components/nodes/chatmodels/ChatHuggingFace/ChatHuggingFace.ts
+++ b/packages/components/nodes/chatmodels/ChatHuggingFace/ChatHuggingFace.ts
@@ -18,7 +18,7 @@ class ChatHuggingFace_ChatModels implements INode {
     constructor() {
         this.label = 'ChatHuggingFace'
         this.name = 'chatHuggingFace'
-        this.version = 2.0
+        this.version = 3.0
         this.type = 'ChatHuggingFace'
         this.icon = 'HuggingFace.svg'
         this.category = 'Chat Models'
@@ -96,6 +96,16 @@ class ChatHuggingFace_ChatModels implements INode {
                 description: 'Frequency Penalty parameter may not apply to certain model. Please check available model parameters',
                 optional: true,
                 additionalParams: true
+            },
+            {
+                label: 'Stop Sequence',
+                name: 'stop',
+                type: 'string',
+                rows: 4,
+                placeholder: 'AI assistant:',
+                description: 'Sets the stop sequences to use. Use comma to seperate different sequences.',
+                optional: true,
+                additionalParams: true
             }
         ]
     }
@@ -109,6 +119,7 @@ class ChatHuggingFace_ChatModels implements INode {
         const frequencyPenalty = nodeData.inputs?.frequencyPenalty as string
         const endpoint = nodeData.inputs?.endpoint as string
         const cache = nodeData.inputs?.cache as BaseCache
+        const stop = nodeData.inputs?.stop as string
 
         const credentialData = await getCredentialData(nodeData.credential ?? '', options)
         const huggingFaceApiKey = getCredentialParam('huggingFaceApiKey', credentialData, nodeData)
@@ -123,7 +134,11 @@ class ChatHuggingFace_ChatModels implements INode {
         if (topP) obj.topP = parseFloat(topP)
         if (hfTopK) obj.topK = parseFloat(hfTopK)
         if (frequencyPenalty) obj.frequencyPenalty = parseFloat(frequencyPenalty)
-        if (endpoint) obj.endpoint = endpoint
+        if (endpoint) obj.endpointUrl = endpoint
+        if (stop) {
+            const stopSequences = stop.split(',')
+            obj.stopSequences = stopSequences
+        }
 
         const huggingFace = new HuggingFaceInference(obj)
         if (cache) huggingFace.cache = cache
diff --git a/packages/components/nodes/chatmodels/ChatHuggingFace/core.ts b/packages/components/nodes/chatmodels/ChatHuggingFace/core.ts
index 7d26c1453..2cf2de25d 100644
--- a/packages/components/nodes/chatmodels/ChatHuggingFace/core.ts
+++ b/packages/components/nodes/chatmodels/ChatHuggingFace/core.ts
@@ -1,32 +1,19 @@
 import { LLM, BaseLLMParams } from '@langchain/core/language_models/llms'
 import { getEnvironmentVariable } from '../../../src/utils'
+import { GenerationChunk } from '@langchain/core/outputs'
+import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
 
 export interface HFInput {
-    /** Model to use */
     model: string
-
-    /** Sampling temperature to use */
     temperature?: number
-
-    /**
-     * Maximum number of tokens to generate in the completion.
-     */
     maxTokens?: number
-
-    /** Total probability mass of tokens to consider at each step */
+    stopSequences?: string[]
     topP?: number
-
-    /** Integer to define the top tokens considered within the sample operation to create new text. */
     topK?: number
-
-    /** Penalizes repeated tokens according to frequency */
     frequencyPenalty?: number
-
-    /** API key to use. */
     apiKey?: string
-
-    /** Private endpoint to use. */
-    endpoint?: string
+    endpointUrl?: string
+    includeCredentials?: string | boolean
 }
 
 export class HuggingFaceInference extends LLM implements HFInput {
@@ -40,6 +27,8 @@ export class HuggingFaceInference extends LLM implements HFInput {
 
     temperature: number | undefined = undefined
 
+    stopSequences: string[] | undefined = undefined
+
     maxTokens: number | undefined = undefined
 
     topP: number | undefined = undefined
@@ -50,7 +39,9 @@ export class HuggingFaceInference extends LLM implements HFInput {
 
     apiKey: string | undefined = undefined
 
-    endpoint: string | undefined = undefined
+    endpointUrl: string | undefined = undefined
+
+    includeCredentials: string | boolean | undefined = undefined
 
     constructor(fields?: Partial<HFInput> & BaseLLMParams) {
         super(fields ?? {})
@@ -58,11 +49,13 @@ export class HuggingFaceInference extends LLM implements HFInput {
         this.model = fields?.model ?? this.model
         this.temperature = fields?.temperature ?? this.temperature
        this.maxTokens = fields?.maxTokens ?? this.maxTokens
+        this.stopSequences = fields?.stopSequences ?? this.stopSequences
         this.topP = fields?.topP ?? this.topP
         this.topK = fields?.topK ?? this.topK
         this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty
-        this.endpoint = fields?.endpoint ?? ''
         this.apiKey = fields?.apiKey ?? getEnvironmentVariable('HUGGINGFACEHUB_API_KEY')
+        this.endpointUrl = fields?.endpointUrl
+        this.includeCredentials = fields?.includeCredentials
         if (!this.apiKey) {
             throw new Error(
                 'Please set an API key for HuggingFace Hub in the environment variable HUGGINGFACEHUB_API_KEY or in the apiKey field of the HuggingFaceInference constructor.'
@@ -74,31 +67,65 @@ export class HuggingFaceInference extends LLM implements HFInput {
         return 'hf'
     }
 
-    /** @ignore */
-    async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
-        const { HfInference } = await HuggingFaceInference.imports()
-        const hf = new HfInference(this.apiKey)
-        const obj: any = {
+    invocationParams(options?: this['ParsedCallOptions']) {
+        return {
+            model: this.model,
             parameters: {
                 // make it behave similar to openai, returning only the generated text
                 return_full_text: false,
                 temperature: this.temperature,
                 max_new_tokens: this.maxTokens,
+                stop: options?.stop ?? this.stopSequences,
                 top_p: this.topP,
                 top_k: this.topK,
                 repetition_penalty: this.frequencyPenalty
-            },
-            inputs: prompt
+            }
         }
-        if (this.endpoint) {
-            hf.endpoint(this.endpoint)
-        } else {
-            obj.model = this.model
+    }
+
+    async *_streamResponseChunks(
+        prompt: string,
+        options: this['ParsedCallOptions'],
+        runManager?: CallbackManagerForLLMRun
+    ): AsyncGenerator<GenerationChunk> {
+        const hfi = await this._prepareHFInference()
+        const stream = await this.caller.call(async () =>
+            hfi.textGenerationStream({
+                ...this.invocationParams(options),
+                inputs: prompt
+            })
+        )
+        for await (const chunk of stream) {
+            const token = chunk.token.text
+            yield new GenerationChunk({ text: token, generationInfo: chunk })
+            await runManager?.handleLLMNewToken(token ?? '')
+
+            // stream is done
+            if (chunk.generated_text)
+                yield new GenerationChunk({
+                    text: '',
+                    generationInfo: { finished: true }
+                })
         }
-        const res = await this.caller.callWithOptions({ signal: options.signal }, hf.textGeneration.bind(hf), obj)
+    }
+
+    /** @ignore */
+    async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
+        const hfi = await this._prepareHFInference()
+        const args = { ...this.invocationParams(options), inputs: prompt }
+        const res = await this.caller.callWithOptions({ signal: options.signal }, hfi.textGeneration.bind(hfi), args)
        return res.generated_text
     }
 
+    /** @ignore */
+    private async _prepareHFInference() {
+        const { HfInference } = await HuggingFaceInference.imports()
+        const hfi = new HfInference(this.apiKey, {
+            includeCredentials: this.includeCredentials
+        })
+        return this.endpointUrl ? hfi.endpoint(this.endpointUrl) : hfi
+    }
+
     /** @ignore */
     static async imports(): Promise<{
         HfInference: typeof import('@huggingface/inference').HfInference
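
Note (editor addition, not part of the patch): the fix routes a custom Inference Endpoint through the renamed endpointUrl field, so _prepareHFInference() now binds the HfInference client to that URL instead of the option being silently dropped. Below is a minimal usage sketch of the reworked wrapper exercising endpointUrl and the new stopSequences option; the import path, endpoint URL, and model id are illustrative assumptions, not values taken from the commit.

// Usage sketch only — paths and values are assumptions for illustration.
import { HuggingFaceInference } from './core'

async function main() {
    const llm = new HuggingFaceInference({
        // Hosted model id; a custom Inference Endpoint serves whatever model it was deployed with
        model: 'mistralai/Mistral-7B-Instruct-v0.2',
        apiKey: process.env.HUGGINGFACEHUB_API_KEY,
        // Renamed from `endpoint` to `endpointUrl` by this patch (hypothetical URL)
        endpointUrl: 'https://my-endpoint.endpoints.huggingface.cloud',
        // Forwarded as `stop` via invocationParams(); the Flowise node builds this
        // array by splitting the comma-separated "Stop Sequence" input
        stopSequences: ['AI assistant:', 'Human:'],
        temperature: 0.7,
        maxTokens: 256
    })

    // invoke() comes from the LangChain LLM base class that HuggingFaceInference extends
    const text = await llm.invoke('Write a haiku about GPUs.')
    console.log(text)
}

main()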