Compression Retriever: Addition of topK to Cohere Rerank Retriever
This commit is contained in:
parent
4dd2f245ff
commit
d0ab21e733
|
|
@ -5,12 +5,13 @@ import axios from 'axios'
|
||||||
export class CohereRerank extends BaseDocumentCompressor {
|
export class CohereRerank extends BaseDocumentCompressor {
|
||||||
private cohereAPIKey: any
|
private cohereAPIKey: any
|
||||||
private COHERE_API_URL = 'https://api.cohere.ai/v1/rerank'
|
private COHERE_API_URL = 'https://api.cohere.ai/v1/rerank'
|
||||||
private model: string
|
private readonly model: string
|
||||||
|
private readonly k: number
|
||||||
constructor(cohereAPIKey: string, model: string) {
|
constructor(cohereAPIKey: string, model: string, k: number) {
|
||||||
super()
|
super()
|
||||||
this.cohereAPIKey = cohereAPIKey
|
this.cohereAPIKey = cohereAPIKey
|
||||||
this.model = model
|
this.model = model
|
||||||
|
this.k = k
|
||||||
}
|
}
|
||||||
async compressDocuments(
|
async compressDocuments(
|
||||||
documents: Document<Record<string, any>>[],
|
documents: Document<Record<string, any>>[],
|
||||||
|
|
@ -30,6 +31,7 @@ export class CohereRerank extends BaseDocumentCompressor {
|
||||||
}
|
}
|
||||||
const data = {
|
const data = {
|
||||||
model: this.model,
|
model: this.model,
|
||||||
|
topN: this.k,
|
||||||
max_chunks_per_doc: 10,
|
max_chunks_per_doc: 10,
|
||||||
query: query,
|
query: query,
|
||||||
return_documents: false,
|
return_documents: false,
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import { BaseRetriever } from 'langchain/schema/retriever'
|
||||||
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression'
|
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression'
|
||||||
import { getCredentialData, getCredentialParam } from '../../../src'
|
import { getCredentialData, getCredentialParam } from '../../../src'
|
||||||
import { CohereRerank } from './CohereRerank'
|
import { CohereRerank } from './CohereRerank'
|
||||||
|
import { VectorStoreRetriever } from 'langchain/vectorstores/base'
|
||||||
|
|
||||||
class CohereRerankRetriever_Retrievers implements INode {
|
class CohereRerankRetriever_Retrievers implements INode {
|
||||||
label: string
|
label: string
|
||||||
|
|
@ -56,6 +57,16 @@ class CohereRerankRetriever_Retrievers implements INode {
|
||||||
],
|
],
|
||||||
default: 'rerank-english-v2.0',
|
default: 'rerank-english-v2.0',
|
||||||
optional: true
|
optional: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Top K',
|
||||||
|
name: 'topK',
|
||||||
|
description: 'Number of top results to fetch. Default to the TopK of the Base Retriever',
|
||||||
|
placeholder: '0',
|
||||||
|
type: 'number',
|
||||||
|
default: 0,
|
||||||
|
additionalParams: true,
|
||||||
|
optional: true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
@ -65,8 +76,14 @@ class CohereRerankRetriever_Retrievers implements INode {
|
||||||
const model = nodeData.inputs?.model as string
|
const model = nodeData.inputs?.model as string
|
||||||
const credentialData = await getCredentialData(nodeData.credential ?? '', options)
|
const credentialData = await getCredentialData(nodeData.credential ?? '', options)
|
||||||
const cohereApiKey = getCredentialParam('cohereApiKey', credentialData, nodeData)
|
const cohereApiKey = getCredentialParam('cohereApiKey', credentialData, nodeData)
|
||||||
|
const topK = nodeData.inputs?.topK as string
|
||||||
|
let k = topK ? parseFloat(topK) : 4
|
||||||
|
|
||||||
const cohereCompressor = new CohereRerank(cohereApiKey, model)
|
if (k <= 0) {
|
||||||
|
k = (baseRetriever as VectorStoreRetriever).k
|
||||||
|
}
|
||||||
|
|
||||||
|
const cohereCompressor = new CohereRerank(cohereApiKey, model, k)
|
||||||
return new ContextualCompressionRetriever({
|
return new ContextualCompressionRetriever({
|
||||||
baseCompressor: cohereCompressor,
|
baseCompressor: cohereCompressor,
|
||||||
baseRetriever: baseRetriever
|
baseRetriever: baseRetriever
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue