fix: Upgrade Hugging Face Inference API to support Inference Providers (#5454)
- Upgrade @huggingface/inference from v2.6.1 to v4.13.2
- Update ChatHuggingFace to use InferenceClient with the chatCompletion API
- Update HuggingFaceInference (LLM) to use the v4 HfInference with Inference Providers
- Update HuggingFaceInferenceEmbedding to use the v4 HfInference
- Add endpoint handling logic to ignore custom endpoints for provider-based models
- Add improved error handling and validation for API keys
- Update UI descriptions to guide users on proper configuration

Fixes #5161

Co-authored-by: Henry <hzj94@hotmail.com>
Parent: 097404f24a
Commit: 0cc7b3036e
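For orientation, here is a minimal before/after sketch of the client migration this commit performs (the model id, prompt, and variable names are illustrative rather than the exact Flowise source; the actual changes are in the diffs below):

// v2.x (before): text-generation task on HfInference
//   const hf = new HfInference(apiKey)
//   const out = await hf.textGeneration({ model: 'gpt2', inputs: 'Hello' })
//   console.log(out.generated_text)

// v4.x (after): OpenAI-compatible chat completion via InferenceClient and Inference Providers
import { InferenceClient } from '@huggingface/inference'

const apiKey = process.env.HUGGINGFACEHUB_API_KEY ?? ''
const client = new InferenceClient(apiKey)
const res = await client.chatCompletion({
    model: 'deepseek-ai/DeepSeek-V3.2-Exp:novita', // the ":novita" suffix selects an Inference Provider
    messages: [{ role: 'user', content: 'Hello' }]
})
console.log(res.choices[0]?.message?.content)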
@@ -1569,16 +1569,20 @@ class Agent_Agentflow implements INode {
     for await (const chunk of await llmNodeInstance.stream(messages, { signal: abortController?.signal })) {
         if (sseStreamer) {
             let content = ''
-            if (Array.isArray(chunk.content) && chunk.content.length > 0) {
+            if (typeof chunk === 'string') {
+                content = chunk
+            } else if (Array.isArray(chunk.content) && chunk.content.length > 0) {
                 const contents = chunk.content as MessageContentText[]
                 content = contents.map((item) => item.text).join('')
-            } else {
+            } else if (chunk.content) {
                 content = chunk.content.toString()
             }
             sseStreamer.streamTokenEvent(chatId, content)
         }

-        response = response.concat(chunk)
+        const messageChunk = typeof chunk === 'string' ? new AIMessageChunk(chunk) : chunk
+        response = response.concat(messageChunk)
     }
 } catch (error) {
     console.error('Error during streaming:', error)
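The hunk above (and the matching ones in HumanInput_Agentflow and LLM_Agentflow below) exists because an LLM wrapper's stream() can yield plain strings while a chat model's stream() yields AIMessageChunk objects. A minimal sketch of the normalization, assuming only those two chunk shapes (helper names are illustrative):

import { AIMessageChunk } from '@langchain/core/messages'

// Text to forward to the SSE streamer
const chunkToText = (chunk: string | AIMessageChunk): string =>
    typeof chunk === 'string' ? chunk : chunk.content.toString()

// Chunk to accumulate into the running AIMessageChunk response
const toMessageChunk = (chunk: string | AIMessageChunk): AIMessageChunk =>
    typeof chunk === 'string' ? new AIMessageChunk(chunk) : chunk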
@@ -241,8 +241,11 @@ class HumanInput_Agentflow implements INode {
 if (isStreamable) {
     const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
     for await (const chunk of await llmNodeInstance.stream(messages)) {
-        sseStreamer.streamTokenEvent(chatId, chunk.content.toString())
-        response = response.concat(chunk)
+        const content = typeof chunk === 'string' ? chunk : chunk.content.toString()
+        sseStreamer.streamTokenEvent(chatId, content)
+
+        const messageChunk = typeof chunk === 'string' ? new AIMessageChunk(chunk) : chunk
+        response = response.concat(messageChunk)
     }
     humanInputDescription = response.content as string
 } else {
@@ -824,16 +824,20 @@ class LLM_Agentflow implements INode {
     for await (const chunk of await llmNodeInstance.stream(messages, { signal: abortController?.signal })) {
         if (sseStreamer) {
             let content = ''
-            if (Array.isArray(chunk.content) && chunk.content.length > 0) {
+            if (typeof chunk === 'string') {
+                content = chunk
+            } else if (Array.isArray(chunk.content) && chunk.content.length > 0) {
                 const contents = chunk.content as MessageContentText[]
                 content = contents.map((item) => item.text).join('')
-            } else {
+            } else if (chunk.content) {
                 content = chunk.content.toString()
             }
             sseStreamer.streamTokenEvent(chatId, content)
         }

-        response = response.concat(chunk)
+        const messageChunk = typeof chunk === 'string' ? new AIMessageChunk(chunk) : chunk
+        response = response.concat(messageChunk)
     }
 } catch (error) {
     console.error('Error during streaming:', error)
@@ -41,15 +41,17 @@ class ChatHuggingFace_ChatModels implements INode {
     label: 'Model',
     name: 'model',
     type: 'string',
-    description: 'If using own inference endpoint, leave this blank',
-    placeholder: 'gpt2'
+    description:
+        'Model name (e.g., deepseek-ai/DeepSeek-V3.2-Exp:novita). If model includes provider (:) or using router endpoint, leave Endpoint blank.',
+    placeholder: 'deepseek-ai/DeepSeek-V3.2-Exp:novita'
 },
 {
     label: 'Endpoint',
     name: 'endpoint',
     type: 'string',
     placeholder: 'https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2',
-    description: 'Using your own inference endpoint',
+    description:
+        'Custom inference endpoint (optional). Not needed for models with providers (:) or router endpoints. Leave blank to use Inference Providers.',
     optional: true
 },
 {
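In practice the two fields above are meant to be used in one of two ways (hypothetical values for illustration only):

// Inference Providers: provider-suffixed model id, Endpoint left blank
const providerConfig = { model: 'deepseek-ai/DeepSeek-V3.2-Exp:novita', endpoint: '' }

// Dedicated Inference Endpoint: Endpoint set, Model may be left blank
const endpointConfig = { model: '', endpoint: 'https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2' }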
@@ -124,6 +126,15 @@ class ChatHuggingFace_ChatModels implements INode {
     const credentialData = await getCredentialData(nodeData.credential ?? '', options)
     const huggingFaceApiKey = getCredentialParam('huggingFaceApiKey', credentialData, nodeData)

+    if (!huggingFaceApiKey) {
+        console.error('[ChatHuggingFace] API key validation failed: No API key found')
+        throw new Error('HuggingFace API key is required. Please configure it in the credential settings.')
+    }
+
+    if (!huggingFaceApiKey.startsWith('hf_')) {
+        console.warn('[ChatHuggingFace] API key format warning: Key does not start with "hf_"')
+    }
+
     const obj: Partial<HFInput> = {
         model,
         apiKey: huggingFaceApiKey
@@ -56,9 +56,9 @@ export class HuggingFaceInference extends LLM implements HFInput {
     this.apiKey = fields?.apiKey ?? getEnvironmentVariable('HUGGINGFACEHUB_API_KEY')
     this.endpointUrl = fields?.endpointUrl
     this.includeCredentials = fields?.includeCredentials
-    if (!this.apiKey) {
+    if (!this.apiKey || this.apiKey.trim() === '') {
         throw new Error(
-            'Please set an API key for HuggingFace Hub in the environment variable HUGGINGFACEHUB_API_KEY or in the apiKey field of the HuggingFaceInference constructor.'
+            'Please set an API key for HuggingFace Hub. Either configure it in the credential settings in the UI, or set the environment variable HUGGINGFACEHUB_API_KEY.'
         )
     }
 }
@@ -68,19 +68,21 @@ export class HuggingFaceInference extends LLM implements HFInput {
     }

     invocationParams(options?: this['ParsedCallOptions']) {
-        return {
-            model: this.model,
-            parameters: {
-                // make it behave similar to openai, returning only the generated text
-                return_full_text: false,
-                temperature: this.temperature,
-                max_new_tokens: this.maxTokens,
-                stop: options?.stop ?? this.stopSequences,
-                top_p: this.topP,
-                top_k: this.topK,
-                repetition_penalty: this.frequencyPenalty
-            }
-        }
+        // Return parameters compatible with chatCompletion API (OpenAI-compatible format)
+        const params: any = {
+            temperature: this.temperature,
+            max_tokens: this.maxTokens,
+            stop: options?.stop ?? this.stopSequences,
+            top_p: this.topP
+        }
+        // Include optional parameters if they are defined
+        if (this.topK !== undefined) {
+            params.top_k = this.topK
+        }
+        if (this.frequencyPenalty !== undefined) {
+            params.frequency_penalty = this.frequencyPenalty
+        }
+        return params
     }

     async *_streamResponseChunks(
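For concreteness, a sketch of what the reworked invocationParams() would now return for a typical node configuration (the values below are illustrative):

// Assuming temperature = 0.7, maxTokens = 512, topP = 0.95, and topK / frequencyPenalty left unset:
const exampleParams = {
    temperature: 0.7,
    max_tokens: 512, // was max_new_tokens in the v2 textGeneration payload
    stop: undefined,
    top_p: 0.95
    // top_k and frequency_penalty are only added when defined;
    // return_full_text is dropped because chatCompletion has no equivalent
}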
@@ -88,51 +90,109 @@ export class HuggingFaceInference extends LLM implements HFInput {
         options: this['ParsedCallOptions'],
         runManager?: CallbackManagerForLLMRun
     ): AsyncGenerator<GenerationChunk> {
-        const hfi = await this._prepareHFInference()
-        const stream = await this.caller.call(async () =>
-            hfi.textGenerationStream({
-                ...this.invocationParams(options),
-                inputs: prompt
-            })
-        )
-        for await (const chunk of stream) {
-            const token = chunk.token.text
-            yield new GenerationChunk({ text: token, generationInfo: chunk })
-            await runManager?.handleLLMNewToken(token ?? '')
-            // stream is done
-            if (chunk.generated_text)
-                yield new GenerationChunk({
-                    text: '',
-                    generationInfo: { finished: true }
-                })
+        try {
+            const client = await this._prepareHFInference()
+            const stream = await this.caller.call(async () =>
+                client.chatCompletionStream({
+                    model: this.model,
+                    messages: [{ role: 'user', content: prompt }],
+                    ...this.invocationParams(options)
+                })
+            )
+            for await (const chunk of stream) {
+                const token = chunk.choices[0]?.delta?.content || ''
+                if (token) {
+                    yield new GenerationChunk({ text: token, generationInfo: chunk })
+                    await runManager?.handleLLMNewToken(token)
+                }
+                // stream is done when finish_reason is set
+                if (chunk.choices[0]?.finish_reason) {
+                    yield new GenerationChunk({
+                        text: '',
+                        generationInfo: { finished: true }
+                    })
+                    break
+                }
+            }
+        } catch (error: any) {
+            console.error('[ChatHuggingFace] Error in _streamResponseChunks:', error)
+            // Provide more helpful error messages
+            if (error?.message?.includes('endpointUrl') || error?.message?.includes('third-party provider')) {
+                throw new Error(
+                    `Cannot use custom endpoint with model "${this.model}" that includes a provider. Please leave the Endpoint field blank in the UI. Original error: ${error.message}`
+                )
+            }
+            throw error
         }
     }

     /** @ignore */
     async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
-        const hfi = await this._prepareHFInference()
-        const args = { ...this.invocationParams(options), inputs: prompt }
-        const res = await this.caller.callWithOptions({ signal: options.signal }, hfi.textGeneration.bind(hfi), args)
-        return res.generated_text
+        try {
+            const client = await this._prepareHFInference()
+            // Use chatCompletion for chat models (v4 supports conversational models via Inference Providers)
+            const args = {
+                model: this.model,
+                messages: [{ role: 'user', content: prompt }],
+                ...this.invocationParams(options)
+            }
+            const res = await this.caller.callWithOptions({ signal: options.signal }, client.chatCompletion.bind(client), args)
+            const content = res.choices[0]?.message?.content || ''
+            if (!content) {
+                console.error('[ChatHuggingFace] No content in response:', JSON.stringify(res))
+                throw new Error(`No content received from HuggingFace API. Response: ${JSON.stringify(res)}`)
+            }
+            return content
+        } catch (error: any) {
+            console.error('[ChatHuggingFace] Error in _call:', error.message)
+            // Provide more helpful error messages
+            if (error?.message?.includes('endpointUrl') || error?.message?.includes('third-party provider')) {
+                throw new Error(
+                    `Cannot use custom endpoint with model "${this.model}" that includes a provider. Please leave the Endpoint field blank in the UI. Original error: ${error.message}`
+                )
+            }
+            if (error?.message?.includes('Invalid username or password') || error?.message?.includes('authentication')) {
+                throw new Error(
+                    `HuggingFace API authentication failed. Please verify your API key is correct and starts with "hf_". Original error: ${error.message}`
+                )
+            }
+            throw error
+        }
     }

     /** @ignore */
     private async _prepareHFInference() {
-        const { HfInference } = await HuggingFaceInference.imports()
-        const hfi = new HfInference(this.apiKey, {
-            includeCredentials: this.includeCredentials
-        })
-        return this.endpointUrl ? hfi.endpoint(this.endpointUrl) : hfi
+        if (!this.apiKey || this.apiKey.trim() === '') {
+            console.error('[ChatHuggingFace] API key validation failed: Empty or undefined')
+            throw new Error('HuggingFace API key is required. Please configure it in the credential settings.')
+        }
+
+        const { InferenceClient } = await HuggingFaceInference.imports()
+        // Use InferenceClient for chat models (works better with Inference Providers)
+        const client = new InferenceClient(this.apiKey)
+
+        // Don't override endpoint if model uses a provider (contains ':') or if endpoint is router-based
+        // When using Inference Providers, endpoint should be left blank - InferenceClient handles routing automatically
+        if (
+            this.endpointUrl &&
+            !this.model.includes(':') &&
+            !this.endpointUrl.includes('/v1/chat/completions') &&
+            !this.endpointUrl.includes('router.huggingface.co')
+        ) {
+            return client.endpoint(this.endpointUrl)
+        }
+
+        // Return client without endpoint override - InferenceClient will use Inference Providers automatically
+        return client
     }

     /** @ignore */
     static async imports(): Promise<{
-        HfInference: typeof import('@huggingface/inference').HfInference
+        InferenceClient: typeof import('@huggingface/inference').InferenceClient
     }> {
         try {
-            const { HfInference } = await import('@huggingface/inference')
-            return { HfInference }
+            const { InferenceClient } = await import('@huggingface/inference')
+            return { InferenceClient }
         } catch (e) {
             throw new Error('Please install huggingface as a dependency with, e.g. `pnpm install @huggingface/inference`')
         }
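The key behavioral change above is in _prepareHFInference(): a custom endpoint is only applied when it cannot conflict with Inference Providers routing. A standalone sketch of that rule, using the same conditions as the diff (the function name is illustrative):

import { InferenceClient } from '@huggingface/inference'

function resolveClient(apiKey: string, model: string, endpointUrl?: string) {
    const client = new InferenceClient(apiKey)
    const usesProvider = model.includes(':') // e.g. "deepseek-ai/DeepSeek-V3.2-Exp:novita"
    const isRouterEndpoint =
        !!endpointUrl &&
        (endpointUrl.includes('router.huggingface.co') || endpointUrl.includes('/v1/chat/completions'))

    // Pin a custom endpoint only for plain model ids and genuinely custom URLs;
    // otherwise let InferenceClient route through Inference Providers automatically.
    return endpointUrl && !usesProvider && !isRouterEndpoint ? client.endpoint(endpointUrl) : client
}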
@@ -23,24 +23,22 @@ export class HuggingFaceInferenceEmbeddings extends Embeddings implements Huggin
     this.model = fields?.model ?? 'sentence-transformers/distilbert-base-nli-mean-tokens'
     this.apiKey = fields?.apiKey ?? getEnvironmentVariable('HUGGINGFACEHUB_API_KEY')
     this.endpoint = fields?.endpoint ?? ''
-    this.client = new HfInference(this.apiKey)
-    if (this.endpoint) this.client.endpoint(this.endpoint)
+    const hf = new HfInference(this.apiKey)
+    // v4 uses Inference Providers by default; only override if custom endpoint provided
+    this.client = this.endpoint ? hf.endpoint(this.endpoint) : hf
 }

 async _embed(texts: string[]): Promise<number[][]> {
     // replace newlines, which can negatively affect performance.
     const clean = texts.map((text) => text.replace(/\n/g, ' '))
-    const hf = new HfInference(this.apiKey)
     const obj: any = {
         inputs: clean
     }
-    if (this.endpoint) {
-        hf.endpoint(this.endpoint)
-    } else {
+    if (!this.endpoint) {
         obj.model = this.model
     }
-    const res = await this.caller.callWithOptions({}, hf.featureExtraction.bind(hf), obj)
+    const res = await this.caller.callWithOptions({}, this.client.featureExtraction.bind(this.client), obj)
     return res as number[][]
 }
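A hedged usage sketch of the embeddings path after this change, assuming the default sentence-transformers model and no custom endpoint (so the model id is passed explicitly, as in the diff):

import { HfInference } from '@huggingface/inference'

async function embed(texts: string[]): Promise<number[][]> {
    const hf = new HfInference(process.env.HUGGINGFACEHUB_API_KEY)
    const res = await hf.featureExtraction({
        model: 'sentence-transformers/distilbert-base-nli-mean-tokens',
        inputs: texts.map((t) => t.replace(/\n/g, ' ')) // newlines can hurt embedding quality
    })
    return res as number[][]
}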
@@ -78,6 +78,8 @@ export class HuggingFaceInference extends LLM implements HFInput {
     async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
         const { HfInference } = await HuggingFaceInference.imports()
         const hf = new HfInference(this.apiKey)
+        // v4 uses Inference Providers by default; only override if custom endpoint provided
+        const hfClient = this.endpoint ? hf.endpoint(this.endpoint) : hf
         const obj: any = {
             parameters: {
                 // make it behave similar to openai, returning only the generated text
@@ -90,12 +92,10 @@ export class HuggingFaceInference extends LLM implements HFInput {
             },
             inputs: prompt
         }
-        if (this.endpoint) {
-            hf.endpoint(this.endpoint)
-        } else {
+        if (!this.endpoint) {
             obj.model = this.model
         }
-        const res = await this.caller.callWithOptions({ signal: options.signal }, hf.textGeneration.bind(hf), obj)
+        const res = await this.caller.callWithOptions({ signal: options.signal }, hfClient.textGeneration.bind(hfClient), obj)
         return res.generated_text
     }
@@ -43,7 +43,7 @@
     "@google-cloud/storage": "^7.15.2",
     "@google/generative-ai": "^0.24.0",
     "@grpc/grpc-js": "^1.10.10",
-    "@huggingface/inference": "^2.6.1",
+    "@huggingface/inference": "^4.13.2",
     "@langchain/anthropic": "0.3.33",
    "@langchain/aws": "^0.1.11",
     "@langchain/baidu-qianfan": "^0.1.0",
File diff suppressed because one or more lines are too long