import { LLM, BaseLLMParams } from '@langchain/core/language_models/llms'
import { getEnvironmentVariable } from '../../../src/utils'

export interface HFInput {
    /** Model to use */
    model: string

    /** Sampling temperature to use */
    temperature?: number

    /**
     * Maximum number of tokens to generate in the completion.
     */
    maxTokens?: number

    /** Total probability mass of tokens to consider at each step */
    topP?: number

    /** Integer to define the top tokens considered within the sample operation to create new text. */
    topK?: number

    /** Penalizes repeated tokens according to frequency */
    frequencyPenalty?: number

    /** API key to use. */
    apiKey?: string

    /** Private endpoint to use. */
    endpoint?: string
}

export class HuggingFaceInference extends LLM implements HFInput {
    get lc_secrets(): { [key: string]: string } | undefined {
        return {
            apiKey: 'HUGGINGFACEHUB_API_KEY'
        }
    }

    model = 'gpt2'

    temperature: number | undefined = undefined

    maxTokens: number | undefined = undefined

    topP: number | undefined = undefined

    topK: number | undefined = undefined

    frequencyPenalty: number | undefined = undefined

    apiKey: string | undefined = undefined

    endpoint: string | undefined = undefined

    constructor(fields?: Partial<HFInput> & BaseLLMParams) {
        super(fields ?? {})

        this.model = fields?.model ?? this.model
        this.temperature = fields?.temperature ?? this.temperature
        this.maxTokens = fields?.maxTokens ?? this.maxTokens
        this.topP = fields?.topP ?? this.topP
        this.topK = fields?.topK ?? this.topK
        this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty
        this.endpoint = fields?.endpoint ?? ''
        this.apiKey = fields?.apiKey ?? getEnvironmentVariable('HUGGINGFACEHUB_API_KEY')
        if (!this.apiKey) {
            throw new Error(
                'Please set an API key for HuggingFace Hub in the environment variable HUGGINGFACEHUB_API_KEY or in the apiKey field of the HuggingFaceInference constructor.'
            )
        }
    }

    _llmType() {
        return 'hf'
    }

    /** @ignore */
    async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
        const { HfInference } = await HuggingFaceInference.imports()
        const hf = new HfInference(this.apiKey)
        // v4 uses Inference Providers by default; only override if custom endpoint provided
        const hfClient = this.endpoint ? hf.endpoint(this.endpoint) : hf
        const obj: any = {
            parameters: {
                // make it behave similar to openai, returning only the generated text
                return_full_text: false,
                temperature: this.temperature,
                max_new_tokens: this.maxTokens,
                top_p: this.topP,
                top_k: this.topK,
                repetition_penalty: this.frequencyPenalty
            },
            inputs: prompt
        }
        if (!this.endpoint) {
            obj.model = this.model
        }
        const res = await this.caller.callWithOptions({ signal: options.signal }, hfClient.textGeneration.bind(hfClient), obj)
        return res.generated_text
    }

    /** @ignore */
    static async imports(): Promise<{
        HfInference: typeof import('@huggingface/inference').HfInference
    }> {
        try {
            const { HfInference } = await import('@huggingface/inference')
            return { HfInference }
        } catch (e) {
            throw new Error('Please install huggingface as a dependency with, e.g. `pnpm add @huggingface/inference`')
        }
    }
}
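
// Example usage (a minimal sketch, not part of this module's exports): this assumes
// HUGGINGFACEHUB_API_KEY is set in the environment (or apiKey is passed explicitly)
// and that the chosen model is reachable through the default Inference Providers
// routing. `invoke` comes from the LangChain Runnable interface that LLM extends.
//
//   const llm = new HuggingFaceInference({ model: 'gpt2', temperature: 0.7, maxTokens: 64 })
//   const text = await llm.invoke('Write a haiku about autumn.')
//   console.log(text)
//
// To target a private Inference Endpoint instead, pass `endpoint`; the model field
// is then ignored, since the endpoint itself determines which model is served.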