From 9c2203be629e3ad5574259d02b7c6739eb73b745 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nguy=E1=BB=85n=20=C4=90=E1=BB=A9c=20H=C3=B9ng?=
<71268621+nguyenhung10012003@users.noreply.github.com>
Date: Mon, 20 Jan 2025 19:33:42 +0700
Subject: [PATCH] Feature: Add Jina AI Rerank Retriever (#3898)
---
.../retrievers/JinaRerankRetriever/JinaAI.svg | 5 +
.../JinaRerankRetriever/JinaRerank.ts | 51 +++++++
.../JinaRerankRetriever.ts | 129 ++++++++++++++++++
3 files changed, 185 insertions(+)
create mode 100644 packages/components/nodes/retrievers/JinaRerankRetriever/JinaAI.svg
create mode 100644 packages/components/nodes/retrievers/JinaRerankRetriever/JinaRerank.ts
create mode 100644 packages/components/nodes/retrievers/JinaRerankRetriever/JinaRerankRetriever.ts
diff --git a/packages/components/nodes/retrievers/JinaRerankRetriever/JinaAI.svg b/packages/components/nodes/retrievers/JinaRerankRetriever/JinaAI.svg
new file mode 100644
index 000000000..95b99d8b5
--- /dev/null
+++ b/packages/components/nodes/retrievers/JinaRerankRetriever/JinaAI.svg
@@ -0,0 +1,5 @@
+
diff --git a/packages/components/nodes/retrievers/JinaRerankRetriever/JinaRerank.ts b/packages/components/nodes/retrievers/JinaRerankRetriever/JinaRerank.ts
new file mode 100644
index 000000000..f55ea53ab
--- /dev/null
+++ b/packages/components/nodes/retrievers/JinaRerankRetriever/JinaRerank.ts
@@ -0,0 +1,51 @@
+import { Callbacks } from '@langchain/core/callbacks/manager'
+import { Document } from '@langchain/core/documents'
+import axios from 'axios'
+import { BaseDocumentCompressor } from 'langchain/retrievers/document_compressors'
+
+export class JinaRerank extends BaseDocumentCompressor {
+ private jinaAPIKey: string
+ private readonly JINA_RERANK_API_URL = 'https://api.jina.ai/v1/rerank'
+ private model: string = 'jina-reranker-v2-base-multilingual'
+ private readonly topN: number
+
+ constructor(jinaAPIKey: string, model: string, topN: number) {
+ super()
+ this.jinaAPIKey = jinaAPIKey
+ this.model = model
+ this.topN = topN
+ }
+ async compressDocuments(
+ documents: Document>[],
+ query: string,
+ _?: Callbacks | undefined
+ ): Promise>[]> {
+ if (documents.length === 0) {
+ return []
+ }
+ const config = {
+ headers: {
+ Authorization: `Bearer ${this.jinaAPIKey}`,
+ 'Content-Type': 'application/json'
+ }
+ }
+ const data = {
+ model: this.model,
+ query: query,
+ documents: documents.map((doc) => doc.pageContent),
+ top_n: this.topN
+ }
+ try {
+ let returnedDocs = await axios.post(this.JINA_RERANK_API_URL, data, config)
+ const finalResults: Document>[] = []
+ returnedDocs.data.results.forEach((result: any) => {
+ const doc = documents[result.index]
+ doc.metadata.relevance_score = result.relevance_score
+ finalResults.push(doc)
+ })
+ return finalResults
+ } catch (error) {
+ return documents
+ }
+ }
+}
diff --git a/packages/components/nodes/retrievers/JinaRerankRetriever/JinaRerankRetriever.ts b/packages/components/nodes/retrievers/JinaRerankRetriever/JinaRerankRetriever.ts
new file mode 100644
index 000000000..3160f3f7a
--- /dev/null
+++ b/packages/components/nodes/retrievers/JinaRerankRetriever/JinaRerankRetriever.ts
@@ -0,0 +1,129 @@
+import { BaseRetriever } from '@langchain/core/retrievers'
+import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression'
+import { getCredentialData, getCredentialParam, handleEscapeCharacters } from '../../../src'
+import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/Interface'
+import { JinaRerank } from './JinaRerank'
+
+class JinaRerankRetriever_Retrievers implements INode {
+ label: string
+ name: string
+ version: number
+ description: string
+ type: string
+ icon: string
+ category: string
+ baseClasses: string[]
+ inputs: INodeParams[]
+ credential: INodeParams
+ badge: string
+ outputs: INodeOutputsValue[]
+
+ constructor() {
+ this.label = 'Jina AI Rerank Retriever'
+ this.name = 'JinaRerankRetriever'
+ this.version = 1.0
+ this.type = 'JinaRerankRetriever'
+ this.icon = 'JinaAI.svg'
+ this.category = 'Retrievers'
+ this.description = 'Jina AI Rerank indexes the documents from most to least semantically relevant to the query.'
+ this.baseClasses = [this.type, 'BaseRetriever']
+ this.credential = {
+ label: 'Connect Credential',
+ name: 'credential',
+ type: 'credential',
+ credentialNames: ['jinaAIApi']
+ }
+ this.inputs = [
+ {
+ label: 'Vector Store Retriever',
+ name: 'baseRetriever',
+ type: 'VectorStoreRetriever'
+ },
+ {
+ label: 'Model Name',
+ name: 'model',
+ type: 'options',
+ options: [
+ {
+ label: 'jina-reranker-v2-base-multilingual',
+ name: 'jina-reranker-v2-base-multilingual'
+ },
+ {
+ label: 'jina-colbert-v2',
+ name: 'jina-colbert-v2'
+ }
+ ],
+ default: 'jina-reranker-v2-base-multilingual',
+ optional: true
+ },
+ {
+ label: 'Query',
+ name: 'query',
+ type: 'string',
+ description: 'Query to retrieve documents from retriever. If not specified, user question will be used',
+ optional: true,
+ acceptVariable: true
+ },
+ {
+ label: 'Top N',
+ name: 'topN',
+ description: 'Number of top results to fetch. Default to 4',
+ placeholder: '4',
+ default: 4,
+ type: 'number',
+ additionalParams: true,
+ optional: true
+ }
+ ]
+ this.outputs = [
+ {
+ label: 'Jina AI Rerank Retriever',
+ name: 'retriever',
+ baseClasses: this.baseClasses
+ },
+ {
+ label: 'Document',
+ name: 'document',
+ description: 'Array of document objects containing metadata and pageContent',
+ baseClasses: ['Document', 'json']
+ },
+ {
+ label: 'Text',
+ name: 'text',
+ description: 'Concatenated string from pageContent of documents',
+ baseClasses: ['string', 'json']
+ }
+ ]
+ }
+
+ async init(nodeData: INodeData, input: string, options: ICommonObject): Promise {
+ const baseRetriever = nodeData.inputs?.baseRetriever as BaseRetriever
+ const model = nodeData.inputs?.model as string
+ const query = nodeData.inputs?.query as string
+ const credentialData = await getCredentialData(nodeData.credential ?? '', options)
+ const jinaApiKey = getCredentialParam('jinaAIAPIKey', credentialData, nodeData)
+ const topN = nodeData.inputs?.topN ? parseFloat(nodeData.inputs?.topN as string) : 4
+ const output = nodeData.outputs?.output as string
+
+ const jinaCompressor = new JinaRerank(jinaApiKey, model, topN)
+
+ const retriever = new ContextualCompressionRetriever({
+ baseCompressor: jinaCompressor,
+ baseRetriever: baseRetriever
+ })
+
+ if (output === 'retriever') return retriever
+ else if (output === 'document') return await retriever.invoke(query ? query : input)
+ else if (output === 'text') {
+ const docs = await retriever.invoke(query ? query : input)
+ let finaltext = ''
+ for (const doc of docs) finaltext += `${doc.pageContent}\n`
+
+ return handleEscapeCharacters(finaltext, false)
+ }
+
+ return retriever
+ }
+}
+
+module.exports = { nodeClass: JinaRerankRetriever_Retrievers }