enable faiss streaming
This commit is contained in:
parent
e9adc15ff2
commit
ab60a6bda1
|
|
@ -2,6 +2,7 @@ import { INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/I
|
|||
import { FaissStore } from 'langchain/vectorstores/faiss'
|
||||
import { Embeddings } from 'langchain/embeddings/base'
|
||||
import { getBaseClasses } from '../../../src/utils'
|
||||
import { Document } from 'langchain/document'
|
||||
|
||||
class Faiss_Existing_VectorStores implements INode {
|
||||
label: string
|
||||
|
|
@ -70,6 +71,23 @@ class Faiss_Existing_VectorStores implements INode {
|
|||
|
||||
const vectorStore = await FaissStore.load(basePath, embeddings)
|
||||
|
||||
// Avoid illegal invocation error
|
||||
vectorStore.similaritySearchVectorWithScore = async (query: number[], k: number) => {
|
||||
const index = vectorStore.index
|
||||
|
||||
if (k > index.ntotal()) {
|
||||
const total = index.ntotal()
|
||||
console.warn(`k (${k}) is greater than the number of elements in the index (${total}), setting k to ${total}`)
|
||||
k = total
|
||||
}
|
||||
|
||||
const result = index.search(query, k)
|
||||
return result.labels.map((id, index) => {
|
||||
const uuid = vectorStore._mapping[id]
|
||||
return [vectorStore.docstore.search(uuid), result.distances[index]] as [Document, number]
|
||||
})
|
||||
}
|
||||
|
||||
if (output === 'retriever') {
|
||||
const retriever = vectorStore.asRetriever(k)
|
||||
return retriever
|
||||
|
|
|
|||
|
|
@ -86,6 +86,23 @@ class FaissUpsert_VectorStores implements INode {
|
|||
const vectorStore = await FaissStore.fromDocuments(finalDocs, embeddings)
|
||||
await vectorStore.save(basePath)
|
||||
|
||||
// Avoid illegal invocation error
|
||||
vectorStore.similaritySearchVectorWithScore = async (query: number[], k: number) => {
|
||||
const index = vectorStore.index
|
||||
|
||||
if (k > index.ntotal()) {
|
||||
const total = index.ntotal()
|
||||
console.warn(`k (${k}) is greater than the number of elements in the index (${total}), setting k to ${total}`)
|
||||
k = total
|
||||
}
|
||||
|
||||
const result = index.search(query, k)
|
||||
return result.labels.map((id, index) => {
|
||||
const uuid = vectorStore._mapping[id]
|
||||
return [vectorStore.docstore.search(uuid), result.distances[index]] as [Document, number]
|
||||
})
|
||||
}
|
||||
|
||||
if (output === 'retriever') {
|
||||
const retriever = vectorStore.asRetriever(k)
|
||||
return retriever
|
||||
|
|
|
|||
|
|
@ -36,7 +36,6 @@ import {
|
|||
isSameOverrideConfig,
|
||||
replaceAllAPIKeys,
|
||||
isFlowValidForStream,
|
||||
isVectorStoreFaiss,
|
||||
databaseEntities,
|
||||
getApiKey,
|
||||
transformToCredentialEntity,
|
||||
|
|
@ -911,7 +910,6 @@ export class App {
|
|||
const nodeModule = await import(nodeInstanceFilePath)
|
||||
const nodeInstance = new nodeModule.nodeClass()
|
||||
|
||||
isStreamValid = isStreamValid && !isVectorStoreFaiss(nodeToExecuteData)
|
||||
logger.debug(`[server]: Running ${nodeToExecuteData.label} (${nodeToExecuteData.id})`)
|
||||
|
||||
if (nodeToExecuteData.instance) checkMemorySessionId(nodeToExecuteData.instance, chatId)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import {
|
|||
IComponentCredentials,
|
||||
ICredentialReqBody
|
||||
} from '../Interface'
|
||||
import { cloneDeep, get, omit, merge, isEqual } from 'lodash'
|
||||
import { cloneDeep, get, isEqual } from 'lodash'
|
||||
import {
|
||||
ICommonObject,
|
||||
getInputVariables,
|
||||
|
|
@ -393,25 +393,6 @@ export const getVariableValue = (
|
|||
return returnVal
|
||||
}
|
||||
|
||||
/**
|
||||
* Temporarily disable streaming if vectorStore is Faiss
|
||||
* @param {INodeData} flowNodeData
|
||||
* @returns {boolean}
|
||||
*/
|
||||
export const isVectorStoreFaiss = (flowNodeData: INodeData) => {
|
||||
if (flowNodeData.inputs && flowNodeData.inputs.vectorStoreRetriever) {
|
||||
const vectorStoreRetriever = flowNodeData.inputs.vectorStoreRetriever
|
||||
if (typeof vectorStoreRetriever === 'string' && vectorStoreRetriever.includes('faiss')) return true
|
||||
if (
|
||||
typeof vectorStoreRetriever === 'object' &&
|
||||
vectorStoreRetriever.vectorStore &&
|
||||
vectorStoreRetriever.vectorStore.constructor.name === 'FaissStore'
|
||||
)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Loop through each inputs and resolve variable if neccessary
|
||||
* @param {INodeData} reactFlowNodeData
|
||||
|
|
@ -426,11 +407,6 @@ export const resolveVariables = (
|
|||
chatHistory: IMessage[]
|
||||
): INodeData => {
|
||||
let flowNodeData = cloneDeep(reactFlowNodeData)
|
||||
if (reactFlowNodeData.instance && isVectorStoreFaiss(reactFlowNodeData)) {
|
||||
// omit and merge because cloneDeep of instance gives "Illegal invocation" Exception
|
||||
const flowNodeDataWithoutInstance = cloneDeep(omit(reactFlowNodeData, ['instance']))
|
||||
flowNodeData = merge(flowNodeDataWithoutInstance, { instance: reactFlowNodeData.instance })
|
||||
}
|
||||
const types = 'inputs'
|
||||
|
||||
const getParamValues = (paramsObj: ICommonObject) => {
|
||||
|
|
@ -819,7 +795,7 @@ export const isFlowValidForStream = (reactFlowNodes: IReactFlowNode[], endingNod
|
|||
isValidChainOrAgent = whitelistAgents.includes(endingNodeData.name)
|
||||
}
|
||||
|
||||
return isChatOrLLMsExist && isValidChainOrAgent && !isVectorStoreFaiss(endingNodeData)
|
||||
return isChatOrLLMsExist && isValidChainOrAgent
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Reference in New Issue