Bugfix/HF custom endpoint (#2811)

Includes the fix for the HuggingFace custom inference endpoint.
Henry Heng · 2024-07-16 21:42:24 +01:00 · committed by GitHub
parent 95b2cf7b7f
commit 54ff43e8f1
2 changed files with 76 additions and 34 deletions
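
What changed, in short: the node used to set `obj.endpoint`, and the bundled wrapper called `hf.endpoint(this.endpoint)` but discarded the client that call returns, so requests never actually went to the custom endpoint. The rewrite renames the field to `endpointUrl` to match the `@huggingface/inference` client, uses the returned endpoint client, adds `includeCredentials`, and adds stop-sequence and streaming support. The sketch below (not part of the diff) shows the client behaviour the fix relies on; the endpoint URL is a hypothetical placeholder.

// Minimal sketch, assuming @huggingface/inference is installed.
// The endpoint URL is a made-up placeholder, not one from this PR.
import { HfInference } from '@huggingface/inference'

async function main() {
    const hfi = new HfInference(process.env.HUGGINGFACEHUB_API_KEY)

    // endpoint() does not mutate hfi; it RETURNS a new client bound to the
    // custom Inference Endpoint, so the return value must be used.
    const client = hfi.endpoint('https://my-model.endpoints.huggingface.cloud')

    const res = await client.textGeneration({
        inputs: 'What is Flowise?',
        parameters: { max_new_tokens: 64, return_full_text: false }
    })
    console.log(res.generated_text)
}

main()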

Changed file 1 of 2 — the ChatHuggingFace chat-model node (class ChatHuggingFace_ChatModels):

@@ -18,7 +18,7 @@ class ChatHuggingFace_ChatModels implements INode {
     constructor() {
         this.label = 'ChatHuggingFace'
         this.name = 'chatHuggingFace'
-        this.version = 2.0
+        this.version = 3.0
         this.type = 'ChatHuggingFace'
         this.icon = 'HuggingFace.svg'
         this.category = 'Chat Models'
@ -96,6 +96,16 @@ class ChatHuggingFace_ChatModels implements INode {
description: 'Frequency Penalty parameter may not apply to certain model. Please check available model parameters', description: 'Frequency Penalty parameter may not apply to certain model. Please check available model parameters',
optional: true, optional: true,
additionalParams: true additionalParams: true
},
{
label: 'Stop Sequence',
name: 'stop',
type: 'string',
rows: 4,
placeholder: 'AI assistant:',
description: 'Sets the stop sequences to use. Use comma to seperate different sequences.',
optional: true,
additionalParams: true
} }
] ]
} }
@@ -109,6 +119,7 @@ class ChatHuggingFace_ChatModels implements INode {
         const frequencyPenalty = nodeData.inputs?.frequencyPenalty as string
         const endpoint = nodeData.inputs?.endpoint as string
         const cache = nodeData.inputs?.cache as BaseCache
+        const stop = nodeData.inputs?.stop as string

         const credentialData = await getCredentialData(nodeData.credential ?? '', options)
         const huggingFaceApiKey = getCredentialParam('huggingFaceApiKey', credentialData, nodeData)
@@ -123,7 +134,11 @@ class ChatHuggingFace_ChatModels implements INode {
         if (topP) obj.topP = parseFloat(topP)
         if (hfTopK) obj.topK = parseFloat(hfTopK)
         if (frequencyPenalty) obj.frequencyPenalty = parseFloat(frequencyPenalty)
-        if (endpoint) obj.endpoint = endpoint
+        if (endpoint) obj.endpointUrl = endpoint
+        if (stop) {
+            const stopSequences = stop.split(',')
+            obj.stopSequences = stopSequences
+        }

         const huggingFace = new HuggingFaceInference(obj)
         if (cache) huggingFace.cache = cache
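
One detail worth noting about the new Stop Sequence input: the node splits the comma-separated string verbatim, with no trimming. A quick illustration (the values are made up):

// The user-supplied string is split on ',' exactly as written, so any space
// after a comma becomes part of the next stop sequence.
const stop = 'AI assistant:,Human:'
const stopSequences = stop.split(',')   // ['AI assistant:', 'Human:']

// With a space after the comma, the second sequence keeps a leading space:
'AI assistant:, Human:'.split(',')      // ['AI assistant:', ' Human:']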

Changed file 2 of 2 — the bundled HuggingFaceInference LLM wrapper:

@@ -1,32 +1,19 @@
 import { LLM, BaseLLMParams } from '@langchain/core/language_models/llms'
 import { getEnvironmentVariable } from '../../../src/utils'
+import { GenerationChunk } from '@langchain/core/outputs'
+import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'

 export interface HFInput {
-    /** Model to use */
     model: string
-    /** Sampling temperature to use */
     temperature?: number
-    /**
-     * Maximum number of tokens to generate in the completion.
-     */
     maxTokens?: number
+    stopSequences?: string[]
-    /** Total probability mass of tokens to consider at each step */
     topP?: number
-    /** Integer to define the top tokens considered within the sample operation to create new text. */
     topK?: number
-    /** Penalizes repeated tokens according to frequency */
     frequencyPenalty?: number
-    /** API key to use. */
     apiKey?: string
-    /** Private endpoint to use. */
-    endpoint?: string
+    endpointUrl?: string
+    includeCredentials?: string | boolean
 }
@@ -40,6 +27,8 @@ export class HuggingFaceInference extends LLM implements HFInput {
     temperature: number | undefined = undefined

+    stopSequences: string[] | undefined = undefined
+
     maxTokens: number | undefined = undefined

     topP: number | undefined = undefined
@@ -50,7 +39,9 @@ export class HuggingFaceInference extends LLM implements HFInput {
     apiKey: string | undefined = undefined

-    endpoint: string | undefined = undefined
+    endpointUrl: string | undefined = undefined
+
+    includeCredentials: string | boolean | undefined = undefined

     constructor(fields?: Partial<HFInput> & BaseLLMParams) {
         super(fields ?? {})
@@ -58,11 +49,13 @@ export class HuggingFaceInference extends LLM implements HFInput {
         this.model = fields?.model ?? this.model
         this.temperature = fields?.temperature ?? this.temperature
         this.maxTokens = fields?.maxTokens ?? this.maxTokens
+        this.stopSequences = fields?.stopSequences ?? this.stopSequences
         this.topP = fields?.topP ?? this.topP
         this.topK = fields?.topK ?? this.topK
         this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty
-        this.endpoint = fields?.endpoint ?? ''
         this.apiKey = fields?.apiKey ?? getEnvironmentVariable('HUGGINGFACEHUB_API_KEY')
+        this.endpointUrl = fields?.endpointUrl
+        this.includeCredentials = fields?.includeCredentials
         if (!this.apiKey) {
             throw new Error(
                 'Please set an API key for HuggingFace Hub in the environment variable HUGGINGFACEHUB_API_KEY or in the apiKey field of the HuggingFaceInference constructor.'
@@ -74,31 +67,65 @@ export class HuggingFaceInference extends LLM implements HFInput {
         return 'hf'
     }

-    /** @ignore */
-    async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
-        const { HfInference } = await HuggingFaceInference.imports()
-        const hf = new HfInference(this.apiKey)
-        const obj: any = {
+    invocationParams(options?: this['ParsedCallOptions']) {
+        return {
+            model: this.model,
             parameters: {
                 // make it behave similar to openai, returning only the generated text
                 return_full_text: false,
                 temperature: this.temperature,
                 max_new_tokens: this.maxTokens,
+                stop: options?.stop ?? this.stopSequences,
                 top_p: this.topP,
                 top_k: this.topK,
                 repetition_penalty: this.frequencyPenalty
-            },
-            inputs: prompt
-        }
-        if (this.endpoint) {
-            hf.endpoint(this.endpoint)
-        } else {
-            obj.model = this.model
-        }
-        const res = await this.caller.callWithOptions({ signal: options.signal }, hf.textGeneration.bind(hf), obj)
+            }
+        }
+    }
+
+    async *_streamResponseChunks(
+        prompt: string,
+        options: this['ParsedCallOptions'],
+        runManager?: CallbackManagerForLLMRun
+    ): AsyncGenerator<GenerationChunk> {
+        const hfi = await this._prepareHFInference()
+        const stream = await this.caller.call(async () =>
+            hfi.textGenerationStream({
+                ...this.invocationParams(options),
+                inputs: prompt
+            })
+        )
+        for await (const chunk of stream) {
+            const token = chunk.token.text
+            yield new GenerationChunk({ text: token, generationInfo: chunk })
+            await runManager?.handleLLMNewToken(token ?? '')
+
+            // stream is done
+            if (chunk.generated_text)
+                yield new GenerationChunk({
+                    text: '',
+                    generationInfo: { finished: true }
+                })
+        }
+    }
+
+    /** @ignore */
+    async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
+        const hfi = await this._prepareHFInference()
+        const args = { ...this.invocationParams(options), inputs: prompt }
+        const res = await this.caller.callWithOptions({ signal: options.signal }, hfi.textGeneration.bind(hfi), args)
         return res.generated_text
     }

+    /** @ignore */
+    private async _prepareHFInference() {
+        const { HfInference } = await HuggingFaceInference.imports()
+        const hfi = new HfInference(this.apiKey, {
+            includeCredentials: this.includeCredentials
+        })
+        return this.endpointUrl ? hfi.endpoint(this.endpointUrl) : hfi
+    }
+
     /** @ignore */
     static async imports(): Promise<{
         HfInference: typeof import('@huggingface/inference').HfInference
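
Taken together, the patched wrapper can be driven like any other LangChain LLM. A minimal usage sketch, assuming the class is imported from this file; the model name and endpoint URL are illustrative placeholders, not values from this PR:

// Hypothetical usage of the patched class. When endpointUrl is set, requests
// are routed to the custom Inference Endpoint rather than the hosted model.
async function demo() {
    const llm = new HuggingFaceInference({
        model: 'gpt2',
        endpointUrl: 'https://my-model.endpoints.huggingface.cloud',
        stopSequences: ['Human:'],
        maxTokens: 128,
        apiKey: process.env.HUGGINGFACEHUB_API_KEY
    })

    // _call path (single-shot textGeneration):
    const text = await llm.invoke('Hello!')
    console.log(text)

    // _streamResponseChunks path (textGenerationStream):
    for await (const chunk of await llm.stream('Hello!')) {
        process.stdout.write(chunk)
    }
}

demo()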