diff --git a/packages/components/nodes/embeddings/AWSBedrockEmbedding/AWSBedrockEmbedding.ts b/packages/components/nodes/embeddings/AWSBedrockEmbedding/AWSBedrockEmbedding.ts index 2b8fc8060..5c5f6352b 100644 --- a/packages/components/nodes/embeddings/AWSBedrockEmbedding/AWSBedrockEmbedding.ts +++ b/packages/components/nodes/embeddings/AWSBedrockEmbedding/AWSBedrockEmbedding.ts @@ -19,7 +19,7 @@ class AWSBedrockEmbedding_Embeddings implements INode { constructor() { this.label = 'AWS Bedrock Embeddings' this.name = 'AWSBedrockEmbeddings' - this.version = 4.0 + this.version = 5.0 this.type = 'AWSBedrockEmbeddings' this.icon = 'aws.svg' this.category = 'Embeddings' @@ -53,11 +53,40 @@ class AWSBedrockEmbedding_Embeddings implements INode { description: 'If provided, will override model selected from Model Name option', type: 'string', optional: true + }, + { + label: 'Cohere Input Type', + name: 'inputType', + type: 'options', + description: + 'Specifies the type of input passed to the model. Required for cohere embedding models v3 and higher. Official Docs', + options: [ + { + label: 'search_document', + name: 'search_document', + description: 'Use this to encode documents for embeddings that you store in a vector database for search use-cases' + }, + { + label: 'search_query', + name: 'search_query', + description: 'Use this when you query your vector DB to find relevant documents.' + }, + { + label: 'classification', + name: 'classification', + description: 'Use this when you use the embeddings as an input to a text classifier' + }, + { + label: 'clustering', + name: 'clustering', + description: 'Use this when you want to cluster the embeddings.' + } + ], + optional: true } ] } - //@ts-ignore loadMethods = { async listModels(): Promise { return await getModels(MODEL_TYPE.EMBEDDING, 'AWSBedrockEmbeddings') @@ -71,6 +100,11 @@ class AWSBedrockEmbedding_Embeddings implements INode { const iRegion = nodeData.inputs?.region as string const iModel = nodeData.inputs?.model as string const customModel = nodeData.inputs?.customModel as string + const inputType = nodeData.inputs?.inputType as string + + if (iModel.startsWith('cohere') && !inputType) { + throw new Error('Input Type must be selected for Cohere models.') + } const obj: BedrockEmbeddingsParams = { model: customModel ? customModel : iModel, @@ -97,20 +131,27 @@ class AWSBedrockEmbedding_Embeddings implements INode { const model = new BedrockEmbeddings(obj) - // Avoid Illegal Invocation model.embedQuery = async (document: string): Promise => { - return await embedText(document, client, iModel) + if (iModel.startsWith('cohere')) { + const embeddings = await embedTextCohere([document], client, iModel, inputType) + return embeddings[0] + } else { + return await embedTextTitan(document, client, iModel) + } } model.embedDocuments = async (documents: string[]): Promise => { - return Promise.all(documents.map((document) => embedText(document, client, iModel))) + if (iModel.startsWith('cohere')) { + return await embedTextCohere(documents, client, iModel, inputType) + } else { + return Promise.all(documents.map((document) => embedTextTitan(document, client, iModel))) + } } return model } } -const embedText = async (text: string, client: BedrockRuntimeClient, model: string): Promise => { - // replace newlines, which can negatively affect performance. +const embedTextTitan = async (text: string, client: BedrockRuntimeClient, model: string): Promise => { const cleanedText = text.replace(/\n/g, ' ') const res = await client.send( @@ -132,4 +173,26 @@ const embedText = async (text: string, client: BedrockRuntimeClient, model: stri } } +const embedTextCohere = async (texts: string[], client: BedrockRuntimeClient, model: string, inputType: string): Promise => { + const cleanedTexts = texts.map((text) => text.replace(/\n/g, ' ')) + + const command = { + modelId: model, + body: JSON.stringify({ + texts: cleanedTexts, + input_type: inputType, + truncate: 'END' + }), + contentType: 'application/json', + accept: 'application/json' + } + const res = await client.send(new InvokeModelCommand(command)) + try { + const body = new TextDecoder().decode(res.body) + return JSON.parse(body).embeddings + } catch (e) { + throw new Error('An invalid response was returned by Bedrock.') + } +} + module.exports = { nodeClass: AWSBedrockEmbedding_Embeddings }