import { BaseDocumentCompressor } from 'langchain/retrievers/document_compressors'
import { Document } from 'langchain/document'
import { Callbacks } from 'langchain/callbacks'
import { BaseLanguageModel } from 'langchain/base_language'
import { ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate } from 'langchain/prompts'
import { LLMChain } from 'langchain/chains'
import { VectorStoreRetriever } from 'langchain/vectorstores/base'

/**
 * Document compressor that re-ranks retrieval results with Reciprocal Rank Fusion (RRF).
 *
 * The incoming query is expanded into several alternative queries by the LLM, every
 * query (the original one first) is run against the base retriever's vector store,
 * and the resulting ranked lists are fused into a single ranking.
 */
export class ReciprocalRankFusion extends BaseDocumentCompressor {
    /** Number of documents requested from the vector store for each query. */
    private static readonly RESULTS_PER_QUERY = 5

    private readonly llm: BaseLanguageModel
    private readonly queryCount: number
    private readonly topK: number
    private readonly c: number
    private baseRetriever: VectorStoreRetriever

    /**
     * @param llm - Model used to generate the alternative search queries.
     * @param baseRetriever - Retriever whose vector store is searched for every query.
     * @param queryCount - Number of alternative queries requested from the LLM in the prompt.
     * @param topK - Maximum number of fused documents returned.
     * @param c - RRF smoothing constant added to each rank before taking the reciprocal.
     */
    constructor(llm: BaseLanguageModel, baseRetriever: VectorStoreRetriever, queryCount: number, topK: number, c: number) {
        super()
        this.queryCount = queryCount
        this.llm = llm
        this.baseRetriever = baseRetriever
        this.topK = topK
        this.c = c
    }

    /**
     * Generates alternative queries, searches the vector store with each of them and
     * returns the fused, re-ranked documents.
     *
     * `documents` is only inspected for emptiness; the returned documents come from
     * fresh similarity searches, not from this array.
     *
     * @param documents - Documents already retrieved for `query`.
     * @param query - The original user query.
     * @param _ - Unused callbacks argument, kept for interface compatibility.
     * @returns At most `topK` documents ordered by descending fused score.
     */
    async compressDocuments(
        documents: Document<Record<string, any>>[],
        query: string,
        _?: Callbacks | undefined
    ): Promise<Document<Record<string, any>>[]> {
        // avoid empty api call
        if (documents.length === 0) {
            return []
        }

        const chatPrompt = ChatPromptTemplate.fromMessages([
            SystemMessagePromptTemplate.fromTemplate(
                'You are a helpful assistant that generates multiple search queries based on a single input query.'
            ),
            HumanMessagePromptTemplate.fromTemplate(
                'Generate multiple search queries related to: {input}. Provide these alternative questions separated by newlines, do not add any numbers.'
            ),
            HumanMessagePromptTemplate.fromTemplate('OUTPUT (' + this.queryCount + ' queries):')
        ])
        const llmChain = new LLMChain({
            llm: this.llm,
            prompt: chatPrompt
        })
        const multipleQueries = await llmChain.call({ input: query })

        // The model output is free text: drop blank lines and surrounding whitespace so
        // that no empty query is sent to the vector store.
        const generatedText: string = typeof multipleQueries.text === 'string' ? multipleQueries.text : ''
        const generatedQueries = generatedText
            .split('\n')
            .map((line) => line.trim())
            .filter((line) => line.length > 0)

        // The original query always comes first; duplicates are removed so that a repeated
        // query cannot contribute the same ranked list twice.
        const queries = Array.from(new Set([query, ...generatedQueries]))

        // The searches are independent of each other, so run them concurrently.
        const docList = await Promise.all(
            queries.map((q) => this.baseRetriever.vectorStore.similaritySearch(q, ReciprocalRankFusion.RESULTS_PER_QUERY))
        )

        return this.reciprocalRankFunction(docList, this.c)
    }

    /**
     * Fuses several ranked document lists with Reciprocal Rank Fusion.
     *
     * A document at 1-based position `rank` of a list contributes `1 / (rank + k)`;
     * contributions of the same document across lists are summed. Documents are
     * identified by their `pageContent`, because each search may return distinct
     * object instances for the same underlying document.
     *
     * The input documents are not mutated: the returned documents are copies whose
     * metadata carries the fused score under `relevancy_score`.
     *
     * @param docList - One ranked list of documents per query, best match first.
     * @param k - Smoothing constant added to each rank.
     * @returns At most `topK` documents sorted by descending fused score; ties keep
     * first-seen order.
     */
    reciprocalRankFunction(docList: Document<Record<string, any>>[][], k: number): Document<Record<string, any>>[] {
        const fused = new Map<string, { doc: Document<Record<string, any>>; score: number }>()

        for (const docs of docList) {
            docs.forEach((doc, index) => {
                const contribution = 1 / (index + 1 + k)
                const entry = fused.get(doc.pageContent)
                if (entry) {
                    entry.score += contribution
                } else {
                    fused.set(doc.pageContent, { doc, score: contribution })
                }
            })
        }

        return Array.from(fused.values())
            .sort((a, b) => b.score - a.score)
            .slice(0, this.topK)
            .map(
                ({ doc, score }) =>
                    new Document({
                        ...doc,
                        metadata: { ...doc.metadata, relevancy_score: score }
                    })
            )
    }
}