diff --git a/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerank.ts b/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerank.ts index 612581edb..55f3c4aad 100644 --- a/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerank.ts +++ b/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerank.ts @@ -5,12 +5,13 @@ import axios from 'axios' export class CohereRerank extends BaseDocumentCompressor { private cohereAPIKey: any private COHERE_API_URL = 'https://api.cohere.ai/v1/rerank' - private model: string - - constructor(cohereAPIKey: string, model: string) { + private readonly model: string + private readonly k: number + constructor(cohereAPIKey: string, model: string, k: number) { super() this.cohereAPIKey = cohereAPIKey this.model = model + this.k = k } async compressDocuments( documents: Document>[], @@ -30,6 +31,7 @@ export class CohereRerank extends BaseDocumentCompressor { } const data = { model: this.model, + topN: this.k, max_chunks_per_doc: 10, query: query, return_documents: false, diff --git a/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerankRetriever.ts b/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerankRetriever.ts index 2e7090bcf..3c1872b3f 100644 --- a/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerankRetriever.ts +++ b/packages/components/nodes/retrievers/CohereRerankRetriever/CohereRerankRetriever.ts @@ -3,6 +3,7 @@ import { BaseRetriever } from 'langchain/schema/retriever' import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression' import { getCredentialData, getCredentialParam } from '../../../src' import { CohereRerank } from './CohereRerank' +import { VectorStoreRetriever } from 'langchain/vectorstores/base' class CohereRerankRetriever_Retrievers implements INode { label: string @@ -56,6 +57,16 @@ class CohereRerankRetriever_Retrievers implements INode { ], default: 'rerank-english-v2.0', optional: true + }, + { + label: 'Top K', + name: 'topK', + description: 'Number of top results to fetch. Default to the TopK of the Base Retriever', + placeholder: '0', + type: 'number', + default: 0, + additionalParams: true, + optional: true } ] } @@ -65,8 +76,14 @@ class CohereRerankRetriever_Retrievers implements INode { const model = nodeData.inputs?.model as string const credentialData = await getCredentialData(nodeData.credential ?? '', options) const cohereApiKey = getCredentialParam('cohereApiKey', credentialData, nodeData) + const topK = nodeData.inputs?.topK as string + let k = topK ? parseFloat(topK) : 4 - const cohereCompressor = new CohereRerank(cohereApiKey, model) + if (k <= 0) { + k = (baseRetriever as VectorStoreRetriever).k + } + + const cohereCompressor = new CohereRerank(cohereApiKey, model, k) return new ContextualCompressionRetriever({ baseCompressor: cohereCompressor, baseRetriever: baseRetriever