Compression Retriever: Reciprocal Rank Fusion

vinodkiran 2023-12-29 20:35:42 +05:30
parent f6ee137ca3
commit 4dd2f245ff
3 changed files with 186 additions and 0 deletions


@@ -0,0 +1,84 @@
import { INode, INodeData, INodeParams } from '../../../src/Interface'
import { BaseLanguageModel } from 'langchain/base_language'
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression'
import { BaseRetriever } from 'langchain/schema/retriever'
import { ReciprocalRankFusion } from './ReciprocalRankFusion'
import { VectorStoreRetriever } from 'langchain/vectorstores/base'

class RRFRetriever_Retrievers implements INode {
    label: string
    name: string
    version: number
    description: string
    type: string
    icon: string
    category: string
    baseClasses: string[]
    inputs: INodeParams[]
    badge: string

    constructor() {
        this.label = 'Reciprocal Rank Fusion Retriever'
        this.name = 'RRFRetriever'
        this.version = 2.0
        this.type = 'RRFRetriever'
        this.badge = 'NEW'
        this.icon = 'compressionRetriever.svg'
        this.category = 'Retrievers'
        this.description = 'Uses Reciprocal Rank Fusion to re-rank search results across multiple generated queries.'
        this.baseClasses = [this.type, 'BaseRetriever']
        this.inputs = [
            {
                label: 'Base Retriever',
                name: 'baseRetriever',
                type: 'VectorStoreRetriever'
            },
            {
                label: 'Language Model',
                name: 'model',
                type: 'BaseLanguageModel'
            },
            {
                label: 'Query Count',
                name: 'queryCount',
                description: 'Number of synthetic queries to generate. Defaults to 4',
                placeholder: '4',
                type: 'number',
                default: 4,
                additionalParams: true,
                optional: true
            },
            {
                label: 'Top K',
                name: 'topK',
                description: 'Number of top results to fetch. Defaults to the Top K of the Base Retriever',
                placeholder: '0',
                type: 'number',
                default: 0,
                additionalParams: true,
                optional: true
            }
        ]
    }

    async init(nodeData: INodeData): Promise<any> {
        const llm = nodeData.inputs?.model as BaseLanguageModel
        const baseRetriever = nodeData.inputs?.baseRetriever as BaseRetriever
        const queryCount = nodeData.inputs?.queryCount as string
        const q = queryCount ? parseFloat(queryCount) : 4
        const topK = nodeData.inputs?.topK as string
        let k = topK ? parseFloat(topK) : 4
        // Fall back to the base retriever's own k when Top K is not set or is 0
        if (k <= 0) {
            k = (baseRetriever as VectorStoreRetriever).k
        }
        // Wrap the base retriever with an RRF compressor that re-ranks its results
        const ragFusion = new ReciprocalRankFusion(llm, baseRetriever as VectorStoreRetriever, q, k)
        return new ContextualCompressionRetriever({
            baseCompressor: ragFusion,
            baseRetriever: baseRetriever
        })
    }
}

module.exports = { nodeClass: RRFRetriever_Retrievers }
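
For illustration only (not part of this commit), the retriever that init() returns can be sketched outside Flowise as follows. The example assumes an OpenAI API key in the environment, uses a small in-memory MemoryVectorStore as a stand-in for the Base Retriever input, and imports the ReciprocalRankFusion compressor defined in the next file.

import { ChatOpenAI } from 'langchain/chat_models/openai'
import { OpenAIEmbeddings } from 'langchain/embeddings/openai'
import { MemoryVectorStore } from 'langchain/vectorstores/memory'
import { Document } from 'langchain/document'
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression'
import { ReciprocalRankFusion } from './ReciprocalRankFusion'

const run = async () => {
    // Small in-memory corpus standing in for the Base Retriever input (hypothetical content)
    const vectorStore = await MemoryVectorStore.fromDocuments(
        [
            new Document({ pageContent: 'Reciprocal Rank Fusion combines rankings from several searches.' }),
            new Document({ pageContent: 'A vector store retriever returns documents by embedding similarity.' })
        ],
        new OpenAIEmbeddings()
    )
    const baseRetriever = vectorStore.asRetriever(4)

    // Same wiring as init() above: the RRF compressor re-ranks what the base retriever returns
    const llm = new ChatOpenAI({ modelName: 'gpt-3.5-turbo', temperature: 0 })
    const ragFusion = new ReciprocalRankFusion(llm, baseRetriever, 4, 4)
    const retriever = new ContextualCompressionRetriever({
        baseCompressor: ragFusion,
        baseRetriever
    })

    const docs = await retriever.getRelevantDocuments('What is reciprocal rank fusion?')
    console.log(docs.map((d) => d.pageContent))
}

run()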


@@ -0,0 +1,95 @@
import { BaseDocumentCompressor } from 'langchain/retrievers/document_compressors'
import { Document } from 'langchain/document'
import { Callbacks } from 'langchain/callbacks'
import { BaseLanguageModel } from 'langchain/base_language'
import { ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate } from 'langchain/prompts'
import { LLMChain } from 'langchain/chains'
import { VectorStoreRetriever } from 'langchain/vectorstores/base'

export class ReciprocalRankFusion extends BaseDocumentCompressor {
    private readonly llm: BaseLanguageModel
    private readonly queryCount: number
    private readonly topK: number
    private baseRetriever: VectorStoreRetriever

    constructor(llm: BaseLanguageModel, baseRetriever: VectorStoreRetriever, queryCount: number, topK: number) {
        super()
        this.queryCount = queryCount
        this.llm = llm
        this.baseRetriever = baseRetriever
        this.topK = topK
    }

    async compressDocuments(
        documents: Document<Record<string, any>>[],
        query: string,
        _?: Callbacks | undefined
    ): Promise<Document<Record<string, any>>[]> {
        // Avoid an empty API call
        if (documents.length === 0) {
            return []
        }
        // Ask the LLM to generate alternative phrasings of the original query
        const chatPrompt = ChatPromptTemplate.fromMessages([
            SystemMessagePromptTemplate.fromTemplate(
                'You are a helpful assistant that generates multiple search queries based on a single input query.'
            ),
            HumanMessagePromptTemplate.fromTemplate(
                'Generate multiple search queries related to: {input}. Provide these alternative questions separated by newlines, do not add any numbers.'
            ),
            HumanMessagePromptTemplate.fromTemplate('OUTPUT (' + this.queryCount + ' queries):')
        ])
        const llmChain = new LLMChain({
            llm: this.llm,
            prompt: chatPrompt
        })
        const multipleQueries = await llmChain.call({ input: query })
        // Keep the original query and append each generated alternative
        const queries: string[] = [query]
        multipleQueries.text.split('\n').forEach((q: string) => {
            queries.push(q)
        })
        console.log(JSON.stringify(queries))
        // Run a similarity search (top 5) against the vector store for every query
        const docList: Document<Record<string, any>>[][] = []
        for (let i = 0; i < queries.length; i++) {
            const resultOne = await this.baseRetriever.vectorStore.similaritySearch(queries[i], 5)
            docList.push([...resultOne])
        }
        return this.reciprocalRankFunction(docList, 60)
    }

    reciprocalRankFunction(docList: Document<Record<string, any>>[][], k: number): Document<Record<string, any>>[] {
        // Score each document as 1 / (rank + k), where rank is its position in its result list
        docList.forEach((docs: Document<Record<string, any>>[]) => {
            docs.forEach((doc: any, index: number) => {
                const rank = index + 1
                if (doc.metadata.relevancy_score) {
                    doc.metadata.relevancy_score += 1 / (rank + k)
                } else {
                    doc.metadata.relevancy_score = 1 / (rank + k)
                }
            })
        })
        // Collect all scores and sort them in descending order
        const scoreArray: number[] = []
        docList.forEach((docs: Document<Record<string, any>>[]) => {
            docs.forEach((doc: any) => {
                scoreArray.push(doc.metadata.relevancy_score)
            })
        })
        scoreArray.sort((a, b) => b - a)
        // Walk the scores from best to worst and emit the first document carrying each unseen score
        const rerankedDocuments: Document<Record<string, any>>[] = []
        const seenScores: number[] = []
        scoreArray.forEach((score) => {
            docList.forEach((docs) => {
                docs.forEach((doc: any) => {
                    if (doc.metadata.relevancy_score === score && seenScores.indexOf(score) === -1) {
                        rerankedDocuments.push(doc)
                        seenScores.push(doc.metadata.relevancy_score)
                    }
                })
            })
        })
        // Keep only the top K fused results
        return rerankedDocuments.splice(0, this.topK)
    }
}
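
For context, reciprocalRankFunction applies the standard RRF formula: a document's fused score is the sum of 1 / (k + rank) over every result list in which it appears, with k = 60 here. A self-contained sketch of that scoring over plain document IDs (hypothetical helper, not part of this commit):

// Hypothetical standalone illustration of Reciprocal Rank Fusion over string IDs.
// Each ranking is an ordered list of document IDs; a lower index means a better rank.
const reciprocalRankFusion = (rankings: string[][], k = 60, topK = 4): string[] => {
    const scores = new Map<string, number>()
    for (const ranking of rankings) {
        ranking.forEach((id, index) => {
            const rank = index + 1
            // Accumulate 1 / (rank + k) across every list the document appears in
            scores.set(id, (scores.get(id) ?? 0) + 1 / (rank + k))
        })
    }
    // Sort by fused score, highest first, and keep the top K IDs
    return [...scores.entries()]
        .sort((a, b) => b[1] - a[1])
        .slice(0, topK)
        .map(([id]) => id)
}

// 'b' ranks well in both lists, so it fuses to the highest score
console.log(reciprocalRankFusion([['a', 'b', 'c'], ['b', 'd']], 60, 3)) // [ 'b', 'a', 'd' ]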


@@ -0,0 +1,7 @@
<svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-chart-bar" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<path d="M3 12m0 1a1 1 0 0 1 1 -1h4a1 1 0 0 1 1 1v6a1 1 0 0 1 -1 1h-4a1 1 0 0 1 -1 -1z" />
<path d="M9 8m0 1a1 1 0 0 1 1 -1h4a1 1 0 0 1 1 1v10a1 1 0 0 1 -1 1h-4a1 1 0 0 1 -1 -1z" />
<path d="M15 4m0 1a1 1 0 0 1 1 -1h4a1 1 0 0 1 1 1v14a1 1 0 0 1 -1 1h-4a1 1 0 0 1 -1 -1z" />
<path d="M4 20l14 0" />
</svg>
