parent 95b2cf7b7f
commit 54ff43e8f1
@@ -18,7 +18,7 @@ class ChatHuggingFace_ChatModels implements INode {
     constructor() {
         this.label = 'ChatHuggingFace'
         this.name = 'chatHuggingFace'
-        this.version = 2.0
+        this.version = 3.0
         this.type = 'ChatHuggingFace'
         this.icon = 'HuggingFace.svg'
         this.category = 'Chat Models'
@@ -96,6 +96,16 @@ class ChatHuggingFace_ChatModels implements INode {
                 description: 'Frequency Penalty parameter may not apply to certain model. Please check available model parameters',
                 optional: true,
                 additionalParams: true
+            },
+            {
+                label: 'Stop Sequence',
+                name: 'stop',
+                type: 'string',
+                rows: 4,
+                placeholder: 'AI assistant:',
+                description: 'Sets the stop sequences to use. Use comma to seperate different sequences.',
+                optional: true,
+                additionalParams: true
             }
         ]
     }
@@ -109,6 +119,7 @@ class ChatHuggingFace_ChatModels implements INode {
         const frequencyPenalty = nodeData.inputs?.frequencyPenalty as string
         const endpoint = nodeData.inputs?.endpoint as string
         const cache = nodeData.inputs?.cache as BaseCache
+        const stop = nodeData.inputs?.stop as string
 
         const credentialData = await getCredentialData(nodeData.credential ?? '', options)
         const huggingFaceApiKey = getCredentialParam('huggingFaceApiKey', credentialData, nodeData)
@@ -123,7 +134,11 @@ class ChatHuggingFace_ChatModels implements INode {
         if (topP) obj.topP = parseFloat(topP)
         if (hfTopK) obj.topK = parseFloat(hfTopK)
         if (frequencyPenalty) obj.frequencyPenalty = parseFloat(frequencyPenalty)
-        if (endpoint) obj.endpoint = endpoint
+        if (endpoint) obj.endpointUrl = endpoint
+        if (stop) {
+            const stopSequences = stop.split(',')
+            obj.stopSequences = stopSequences
+        }
 
         const huggingFace = new HuggingFaceInference(obj)
         if (cache) huggingFace.cache = cache
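
Note on the node-side change above: the new Stop Sequence input arrives as a single comma-separated string and is split verbatim, so any whitespace around the commas is kept in the resulting sequences. A minimal sketch of that parsing behavior (standalone; the sample values are illustrative, not from this commit):

    // Mirrors the node's stop handling: String.prototype.split(',') keeps
    // surrounding whitespace, so ' Human:' (with a leading space) would be
    // sent to the Hugging Face API as-is.
    const stop = 'AI assistant:,Human:'
    const stopSequences: string[] = stop.split(',') // ['AI assistant:', 'Human:']
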
@@ -1,32 +1,19 @@
 import { LLM, BaseLLMParams } from '@langchain/core/language_models/llms'
 import { getEnvironmentVariable } from '../../../src/utils'
+import { GenerationChunk } from '@langchain/core/outputs'
+import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
 
 export interface HFInput {
-    /** Model to use */
     model: string
-
-    /** Sampling temperature to use */
     temperature?: number
-
-    /**
-     * Maximum number of tokens to generate in the completion.
-     */
     maxTokens?: number
-
-    /** Total probability mass of tokens to consider at each step */
+    stopSequences?: string[]
     topP?: number
-
-    /** Integer to define the top tokens considered within the sample operation to create new text. */
     topK?: number
-
-    /** Penalizes repeated tokens according to frequency */
     frequencyPenalty?: number
-
-    /** API key to use. */
     apiKey?: string
-
-    /** Private endpoint to use. */
-    endpoint?: string
+    endpointUrl?: string
+    includeCredentials?: string | boolean
 }
 
 export class HuggingFaceInference extends LLM implements HFInput {
@@ -40,6 +27,8 @@ export class HuggingFaceInference extends LLM implements HFInput {
 
     temperature: number | undefined = undefined
 
+    stopSequences: string[] | undefined = undefined
+
     maxTokens: number | undefined = undefined
 
     topP: number | undefined = undefined
@@ -50,7 +39,9 @@ export class HuggingFaceInference extends LLM implements HFInput {
 
     apiKey: string | undefined = undefined
 
-    endpoint: string | undefined = undefined
+    endpointUrl: string | undefined = undefined
 
+    includeCredentials: string | boolean | undefined = undefined
+
     constructor(fields?: Partial<HFInput> & BaseLLMParams) {
         super(fields ?? {})
@@ -58,11 +49,13 @@ export class HuggingFaceInference extends LLM implements HFInput {
         this.model = fields?.model ?? this.model
         this.temperature = fields?.temperature ?? this.temperature
         this.maxTokens = fields?.maxTokens ?? this.maxTokens
+        this.stopSequences = fields?.stopSequences ?? this.stopSequences
         this.topP = fields?.topP ?? this.topP
         this.topK = fields?.topK ?? this.topK
         this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty
-        this.endpoint = fields?.endpoint ?? ''
         this.apiKey = fields?.apiKey ?? getEnvironmentVariable('HUGGINGFACEHUB_API_KEY')
+        this.endpointUrl = fields?.endpointUrl
+        this.includeCredentials = fields?.includeCredentials
         if (!this.apiKey) {
             throw new Error(
                 'Please set an API key for HuggingFace Hub in the environment variable HUGGINGFACEHUB_API_KEY or in the apiKey field of the HuggingFaceInference constructor.'
@@ -74,31 +67,65 @@ export class HuggingFaceInference extends LLM implements HFInput {
         return 'hf'
     }
 
-    /** @ignore */
-    async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
-        const { HfInference } = await HuggingFaceInference.imports()
-        const hf = new HfInference(this.apiKey)
-        const obj: any = {
+    invocationParams(options?: this['ParsedCallOptions']) {
+        return {
+            model: this.model,
             parameters: {
                 // make it behave similar to openai, returning only the generated text
                 return_full_text: false,
                 temperature: this.temperature,
                 max_new_tokens: this.maxTokens,
+                stop: options?.stop ?? this.stopSequences,
                 top_p: this.topP,
                 top_k: this.topK,
                 repetition_penalty: this.frequencyPenalty
-            },
-            inputs: prompt
+            }
         }
-        if (this.endpoint) {
-            hf.endpoint(this.endpoint)
-        } else {
-            obj.model = this.model
+    }
+
+    async *_streamResponseChunks(
+        prompt: string,
+        options: this['ParsedCallOptions'],
+        runManager?: CallbackManagerForLLMRun
+    ): AsyncGenerator<GenerationChunk> {
+        const hfi = await this._prepareHFInference()
+        const stream = await this.caller.call(async () =>
+            hfi.textGenerationStream({
+                ...this.invocationParams(options),
+                inputs: prompt
+            })
+        )
+        for await (const chunk of stream) {
+            const token = chunk.token.text
+            yield new GenerationChunk({ text: token, generationInfo: chunk })
+            await runManager?.handleLLMNewToken(token ?? '')
+
+            // stream is done
+            if (chunk.generated_text)
+                yield new GenerationChunk({
+                    text: '',
+                    generationInfo: { finished: true }
+                })
         }
-        const res = await this.caller.callWithOptions({ signal: options.signal }, hf.textGeneration.bind(hf), obj)
+    }
+
+    /** @ignore */
+    async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
+        const hfi = await this._prepareHFInference()
+        const args = { ...this.invocationParams(options), inputs: prompt }
+        const res = await this.caller.callWithOptions({ signal: options.signal }, hfi.textGeneration.bind(hfi), args)
         return res.generated_text
     }
 
+    /** @ignore */
+    private async _prepareHFInference() {
+        const { HfInference } = await HuggingFaceInference.imports()
+        const hfi = new HfInference(this.apiKey, {
+            includeCredentials: this.includeCredentials
+        })
+        return this.endpointUrl ? hfi.endpoint(this.endpointUrl) : hfi
+    }
+
     /** @ignore */
     static async imports(): Promise<{
         HfInference: typeof import('@huggingface/inference').HfInference
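
Taken together, the refactor in the second file routes both _call and the new _streamResponseChunks through the shared invocationParams and _prepareHFInference helpers, so stop sequences, the endpointUrl override, and includeCredentials behave identically for streaming and non-streaming calls. A hedged usage sketch of the resulting surface (the import path, model name, and endpoint URL are assumptions for illustration, not taken from this commit):

    import { HuggingFaceInference } from './core' // path is an assumption

    const llm = new HuggingFaceInference({
        model: 'gpt2', // illustrative; ignored when endpointUrl routes to a private endpoint
        apiKey: process.env.HUGGINGFACEHUB_API_KEY,
        stopSequences: ['Human:'] // sent as parameters.stop on every request
        // endpointUrl: 'https://my-endpoint.example', // would switch calls to hfi.endpoint(...)
        // includeCredentials: true // forwarded to the HfInference constructor
    })

    // Non-streaming path: _call -> textGeneration
    const text = await llm.invoke('Tell me a joke.')

    // Streaming path: _streamResponseChunks -> textGenerationStream.
    // A call-time stop overrides the constructor-level stopSequences
    // (stop: options?.stop ?? this.stopSequences).
    for await (const chunk of await llm.stream('Tell me a joke.', { stop: ['\n\n'] })) {
        process.stdout.write(chunk)
    }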