Compression Retriever - Cohere Rerank - Add max chunks per document as optional parameter

This commit is contained in:
vinodkiran 2024-01-17 18:29:37 +05:30
parent 0d19dc5be4
commit 3407fa92f4
2 changed files with 16 additions and 3 deletions

View File

@ -7,11 +7,13 @@ export class CohereRerank extends BaseDocumentCompressor {
private COHERE_API_URL = 'https://api.cohere.ai/v1/rerank' private COHERE_API_URL = 'https://api.cohere.ai/v1/rerank'
private readonly model: string private readonly model: string
private readonly k: number private readonly k: number
constructor(cohereAPIKey: string, model: string, k: number) { private readonly maxChunksPerDoc: number
constructor(cohereAPIKey: string, model: string, k: number, maxChunksPerDoc: number) {
super() super()
this.cohereAPIKey = cohereAPIKey this.cohereAPIKey = cohereAPIKey
this.model = model this.model = model
this.k = k this.k = k
this.maxChunksPerDoc = maxChunksPerDoc
} }
async compressDocuments( async compressDocuments(
documents: Document<Record<string, any>>[], documents: Document<Record<string, any>>[],
@ -32,7 +34,7 @@ export class CohereRerank extends BaseDocumentCompressor {
const data = { const data = {
model: this.model, model: this.model,
topN: this.k, topN: this.k,
max_chunks_per_doc: 10, max_chunks_per_doc: this.maxChunksPerDoc,
query: query, query: query,
return_documents: false, return_documents: false,
documents: documents.map((doc) => doc.pageContent) documents: documents.map((doc) => doc.pageContent)

View File

@ -67,6 +67,15 @@ class CohereRerankRetriever_Retrievers implements INode {
default: 0, default: 0,
additionalParams: true, additionalParams: true,
optional: true optional: true
},
{
label: 'Max Chunks Per Document',
name: 'maxChunksPerDoc',
placeholder: '10',
type: 'number',
default: 10,
additionalParams: true,
optional: true
} }
] ]
} }
@ -78,12 +87,14 @@ class CohereRerankRetriever_Retrievers implements INode {
const cohereApiKey = getCredentialParam('cohereApiKey', credentialData, nodeData) const cohereApiKey = getCredentialParam('cohereApiKey', credentialData, nodeData)
const topK = nodeData.inputs?.topK as string const topK = nodeData.inputs?.topK as string
let k = topK ? parseFloat(topK) : 4 let k = topK ? parseFloat(topK) : 4
const maxChunks = nodeData.inputs?.maxChunksPerDoc as string
let max = maxChunks ? parseInt(maxChunks) : 10
if (k <= 0) { if (k <= 0) {
k = (baseRetriever as VectorStoreRetriever).k k = (baseRetriever as VectorStoreRetriever).k
} }
const cohereCompressor = new CohereRerank(cohereApiKey, model, k) const cohereCompressor = new CohereRerank(cohereApiKey, model, k, max)
return new ContextualCompressionRetriever({ return new ContextualCompressionRetriever({
baseCompressor: cohereCompressor, baseCompressor: cohereCompressor,
baseRetriever: baseRetriever baseRetriever: baseRetriever