Added bedrock cohere embed model requirements (#2207)

This commit is contained in:
Quinn 2024-04-18 03:51:48 -07:00 committed by GitHub
parent 713077381b
commit e4ab2a9e33
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 70 additions and 7 deletions

View File

@ -19,7 +19,7 @@ class AWSBedrockEmbedding_Embeddings implements INode {
constructor() {
this.label = 'AWS Bedrock Embeddings'
this.name = 'AWSBedrockEmbeddings'
this.version = 4.0
this.version = 5.0
this.type = 'AWSBedrockEmbeddings'
this.icon = 'aws.svg'
this.category = 'Embeddings'
@ -53,11 +53,40 @@ class AWSBedrockEmbedding_Embeddings implements INode {
description: 'If provided, will override model selected from Model Name option',
type: 'string',
optional: true
},
{
label: 'Cohere Input Type',
name: 'inputType',
type: 'options',
description:
'Specifies the type of input passed to the model. Required for cohere embedding models v3 and higher. <a target="_blank" href="https://docs.cohere.com/reference/embed">Official Docs</a>',
options: [
{
label: 'search_document',
name: 'search_document',
description: 'Use this to encode documents for embeddings that you store in a vector database for search use-cases'
},
{
label: 'search_query',
name: 'search_query',
description: 'Use this when you query your vector DB to find relevant documents.'
},
{
label: 'classification',
name: 'classification',
description: 'Use this when you use the embeddings as an input to a text classifier'
},
{
label: 'clustering',
name: 'clustering',
description: 'Use this when you want to cluster the embeddings.'
}
],
optional: true
}
]
}
//@ts-ignore
loadMethods = {
async listModels(): Promise<INodeOptionsValue[]> {
return await getModels(MODEL_TYPE.EMBEDDING, 'AWSBedrockEmbeddings')
@ -71,6 +100,11 @@ class AWSBedrockEmbedding_Embeddings implements INode {
const iRegion = nodeData.inputs?.region as string
const iModel = nodeData.inputs?.model as string
const customModel = nodeData.inputs?.customModel as string
const inputType = nodeData.inputs?.inputType as string
if (iModel.startsWith('cohere') && !inputType) {
throw new Error('Input Type must be selected for Cohere models.')
}
const obj: BedrockEmbeddingsParams = {
model: customModel ? customModel : iModel,
@ -97,20 +131,27 @@ class AWSBedrockEmbedding_Embeddings implements INode {
const model = new BedrockEmbeddings(obj)
// Avoid Illegal Invocation
model.embedQuery = async (document: string): Promise<number[]> => {
return await embedText(document, client, iModel)
if (iModel.startsWith('cohere')) {
const embeddings = await embedTextCohere([document], client, iModel, inputType)
return embeddings[0]
} else {
return await embedTextTitan(document, client, iModel)
}
}
model.embedDocuments = async (documents: string[]): Promise<number[][]> => {
return Promise.all(documents.map((document) => embedText(document, client, iModel)))
if (iModel.startsWith('cohere')) {
return await embedTextCohere(documents, client, iModel, inputType)
} else {
return Promise.all(documents.map((document) => embedTextTitan(document, client, iModel)))
}
}
return model
}
}
const embedText = async (text: string, client: BedrockRuntimeClient, model: string): Promise<number[]> => {
// replace newlines, which can negatively affect performance.
const embedTextTitan = async (text: string, client: BedrockRuntimeClient, model: string): Promise<number[]> => {
const cleanedText = text.replace(/\n/g, ' ')
const res = await client.send(
@ -132,4 +173,26 @@ const embedText = async (text: string, client: BedrockRuntimeClient, model: stri
}
}
const embedTextCohere = async (texts: string[], client: BedrockRuntimeClient, model: string, inputType: string): Promise<number[][]> => {
const cleanedTexts = texts.map((text) => text.replace(/\n/g, ' '))
const command = {
modelId: model,
body: JSON.stringify({
texts: cleanedTexts,
input_type: inputType,
truncate: 'END'
}),
contentType: 'application/json',
accept: 'application/json'
}
const res = await client.send(new InvokeModelCommand(command))
try {
const body = new TextDecoder().decode(res.body)
return JSON.parse(body).embeddings
} catch (e) {
throw new Error('An invalid response was returned by Bedrock.')
}
}
module.exports = { nodeClass: AWSBedrockEmbedding_Embeddings }