Flowise/packages/components/nodes/chatmodels/ChatHuggingFace/core.ts

import { LLM, BaseLLMParams } from '@langchain/core/language_models/llms'
import { GenerationChunk } from '@langchain/core/outputs'
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
import { getEnvironmentVariable } from '../../../src/utils'
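
/** Parameters accepted by the Hugging Face Inference API wrapper. */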
export interface HFInput {
    /** Model name to call, e.g. `gpt2` */
    model: string
    /** Sampling temperature */
    temperature?: number
    /** Maximum number of new tokens to generate */
    maxTokens?: number
    /** Sequences at which generation should stop */
    stopSequences?: string[]
    /** Nucleus sampling: total probability mass of tokens to consider */
    topP?: number
    /** Sample from only the top K most likely tokens */
    topK?: number
    /** Penalty for repeated tokens (sent to the API as `repetition_penalty`) */
    frequencyPenalty?: number
    /** API key; falls back to the HUGGINGFACEHUB_API_KEY environment variable */
    apiKey?: string
    /** URL of a custom inference endpoint to use instead of the hosted model */
    endpointUrl?: string
    /** Credentials option passed through to the HfInference client */
    includeCredentials?: string | boolean
}
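
/**
 * LLM wrapper around the Hugging Face Inference API for text generation.
 * Lazily imports `@huggingface/inference` and supports both hosted models
 * and dedicated inference endpoints via `endpointUrl`.
 */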
export class HuggingFaceInference extends LLM implements HFInput {
    get lc_secrets(): { [key: string]: string } | undefined {
        return {
            apiKey: 'HUGGINGFACEHUB_API_KEY'
        }
    }

    model = 'gpt2'
    temperature: number | undefined = undefined
    stopSequences: string[] | undefined = undefined
    maxTokens: number | undefined = undefined
    topP: number | undefined = undefined
    topK: number | undefined = undefined
    frequencyPenalty: number | undefined = undefined
    apiKey: string | undefined = undefined
    endpointUrl: string | undefined = undefined
    includeCredentials: string | boolean | undefined = undefined

    constructor(fields?: Partial<HFInput> & BaseLLMParams) {
        super(fields ?? {})
        this.model = fields?.model ?? this.model
        this.temperature = fields?.temperature ?? this.temperature
        this.maxTokens = fields?.maxTokens ?? this.maxTokens
        this.stopSequences = fields?.stopSequences ?? this.stopSequences
        this.topP = fields?.topP ?? this.topP
        this.topK = fields?.topK ?? this.topK
        this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty
        this.apiKey = fields?.apiKey ?? getEnvironmentVariable('HUGGINGFACEHUB_API_KEY')
        this.endpointUrl = fields?.endpointUrl
        this.includeCredentials = fields?.includeCredentials
        if (!this.apiKey) {
            throw new Error(
                'Please set an API key for HuggingFace Hub in the environment variable HUGGINGFACEHUB_API_KEY or in the apiKey field of the HuggingFaceInference constructor.'
            )
        }
    }
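
    /** Identifier for this LLM type, used by LangChain serialization and callbacks. */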
    _llmType() {
        return 'hf'
    }
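
    /**
     * Assemble the request payload for the HF text-generation endpoint.
     * Stop sequences passed at call time override the configured ones.
     */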
    invocationParams(options?: this['ParsedCallOptions']) {
        return {
            model: this.model,
            parameters: {
                // make it behave similar to openai, returning only the generated text
                return_full_text: false,
                temperature: this.temperature,
                max_new_tokens: this.maxTokens,
                stop: options?.stop ?? this.stopSequences,
                top_p: this.topP,
                top_k: this.topK,
                repetition_penalty: this.frequencyPenalty
            }
        }
    }
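
    /**
     * Stream tokens from the model, yielding one GenerationChunk per token and a
     * final empty chunk marked `finished` once the API returns the full generated_text.
     */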
    async *_streamResponseChunks(
        prompt: string,
        options: this['ParsedCallOptions'],
        runManager?: CallbackManagerForLLMRun
    ): AsyncGenerator<GenerationChunk> {
        const hfi = await this._prepareHFInference()
        const stream = await this.caller.call(async () =>
            hfi.textGenerationStream({
                ...this.invocationParams(options),
                inputs: prompt
            })
        )
        for await (const chunk of stream) {
            const token = chunk.token.text
            yield new GenerationChunk({ text: token, generationInfo: chunk })
            await runManager?.handleLLMNewToken(token ?? '')

            // the API sets generated_text on the last chunk, signalling the stream is done
            if (chunk.generated_text) {
                yield new GenerationChunk({
                    text: '',
                    generationInfo: { finished: true }
                })
            }
        }
    }

    /** @ignore */
    async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
        const hfi = await this._prepareHFInference()
        const args = { ...this.invocationParams(options), inputs: prompt }
        const res = await this.caller.callWithOptions({ signal: options.signal }, hfi.textGeneration.bind(hfi), args)
        return res.generated_text
    }

    /** @ignore */
    private async _prepareHFInference() {
        const { HfInference } = await HuggingFaceInference.imports()
        const hfi = new HfInference(this.apiKey, {
            includeCredentials: this.includeCredentials
        })
        return this.endpointUrl ? hfi.endpoint(this.endpointUrl) : hfi
    }

    /** @ignore */
    static async imports(): Promise<{
        HfInference: typeof import('@huggingface/inference').HfInference
    }> {
        try {
            const { HfInference } = await import('@huggingface/inference')
            return { HfInference }
        } catch (e) {
            throw new Error('Please install @huggingface/inference as a dependency, e.g. `pnpm install @huggingface/inference`')
        }
    }
}
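
// Example usage (a minimal sketch, not part of this module): assumes
// HUGGINGFACEHUB_API_KEY is set in the environment and that the chosen model
// is a hosted text-generation model your token can access; any such model id
// can be substituted.
//
//   import { HuggingFaceInference } from './core'
//
//   const llm = new HuggingFaceInference({
//     model: 'gpt2',
//     temperature: 0.7,
//     maxTokens: 100
//   })
//   const text = await llm.invoke('Once upon a time,')         // non-streaming, uses _call
//   for await (const chunk of await llm.stream('Once upon a time,')) {
//     process.stdout.write(chunk)                              // streaming, uses _streamResponseChunks
//   }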