enable faiss streaming

This commit is contained in:
Henry 2023-08-17 18:38:38 +01:00
parent e9adc15ff2
commit ab60a6bda1
4 changed files with 37 additions and 28 deletions

View File

@ -2,6 +2,7 @@ import { INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/I
import { FaissStore } from 'langchain/vectorstores/faiss'
import { Embeddings } from 'langchain/embeddings/base'
import { getBaseClasses } from '../../../src/utils'
import { Document } from 'langchain/document'
class Faiss_Existing_VectorStores implements INode {
label: string
@ -70,6 +71,23 @@ class Faiss_Existing_VectorStores implements INode {
const vectorStore = await FaissStore.load(basePath, embeddings)
// Avoid illegal invocation error
vectorStore.similaritySearchVectorWithScore = async (query: number[], k: number) => {
const index = vectorStore.index
if (k > index.ntotal()) {
const total = index.ntotal()
console.warn(`k (${k}) is greater than the number of elements in the index (${total}), setting k to ${total}`)
k = total
}
const result = index.search(query, k)
return result.labels.map((id, index) => {
const uuid = vectorStore._mapping[id]
return [vectorStore.docstore.search(uuid), result.distances[index]] as [Document, number]
})
}
if (output === 'retriever') {
const retriever = vectorStore.asRetriever(k)
return retriever

View File

@ -86,6 +86,23 @@ class FaissUpsert_VectorStores implements INode {
const vectorStore = await FaissStore.fromDocuments(finalDocs, embeddings)
await vectorStore.save(basePath)
// Avoid illegal invocation error
vectorStore.similaritySearchVectorWithScore = async (query: number[], k: number) => {
const index = vectorStore.index
if (k > index.ntotal()) {
const total = index.ntotal()
console.warn(`k (${k}) is greater than the number of elements in the index (${total}), setting k to ${total}`)
k = total
}
const result = index.search(query, k)
return result.labels.map((id, index) => {
const uuid = vectorStore._mapping[id]
return [vectorStore.docstore.search(uuid), result.distances[index]] as [Document, number]
})
}
if (output === 'retriever') {
const retriever = vectorStore.asRetriever(k)
return retriever

View File

@ -36,7 +36,6 @@ import {
isSameOverrideConfig,
replaceAllAPIKeys,
isFlowValidForStream,
isVectorStoreFaiss,
databaseEntities,
getApiKey,
transformToCredentialEntity,
@ -911,7 +910,6 @@ export class App {
const nodeModule = await import(nodeInstanceFilePath)
const nodeInstance = new nodeModule.nodeClass()
isStreamValid = isStreamValid && !isVectorStoreFaiss(nodeToExecuteData)
logger.debug(`[server]: Running ${nodeToExecuteData.label} (${nodeToExecuteData.id})`)
if (nodeToExecuteData.instance) checkMemorySessionId(nodeToExecuteData.instance, chatId)

View File

@ -18,7 +18,7 @@ import {
IComponentCredentials,
ICredentialReqBody
} from '../Interface'
import { cloneDeep, get, omit, merge, isEqual } from 'lodash'
import { cloneDeep, get, isEqual } from 'lodash'
import {
ICommonObject,
getInputVariables,
@ -393,25 +393,6 @@ export const getVariableValue = (
return returnVal
}
/**
* Temporarily disable streaming if vectorStore is Faiss
* @param {INodeData} flowNodeData
* @returns {boolean}
*/
export const isVectorStoreFaiss = (flowNodeData: INodeData) => {
if (flowNodeData.inputs && flowNodeData.inputs.vectorStoreRetriever) {
const vectorStoreRetriever = flowNodeData.inputs.vectorStoreRetriever
if (typeof vectorStoreRetriever === 'string' && vectorStoreRetriever.includes('faiss')) return true
if (
typeof vectorStoreRetriever === 'object' &&
vectorStoreRetriever.vectorStore &&
vectorStoreRetriever.vectorStore.constructor.name === 'FaissStore'
)
return true
}
return false
}
/**
* Loop through each inputs and resolve variable if neccessary
* @param {INodeData} reactFlowNodeData
@ -426,11 +407,6 @@ export const resolveVariables = (
chatHistory: IMessage[]
): INodeData => {
let flowNodeData = cloneDeep(reactFlowNodeData)
if (reactFlowNodeData.instance && isVectorStoreFaiss(reactFlowNodeData)) {
// omit and merge because cloneDeep of instance gives "Illegal invocation" Exception
const flowNodeDataWithoutInstance = cloneDeep(omit(reactFlowNodeData, ['instance']))
flowNodeData = merge(flowNodeDataWithoutInstance, { instance: reactFlowNodeData.instance })
}
const types = 'inputs'
const getParamValues = (paramsObj: ICommonObject) => {
@ -819,7 +795,7 @@ export const isFlowValidForStream = (reactFlowNodes: IReactFlowNode[], endingNod
isValidChainOrAgent = whitelistAgents.includes(endingNodeData.name)
}
return isChatOrLLMsExist && isValidChainOrAgent && !isVectorStoreFaiss(endingNodeData)
return isChatOrLLMsExist && isValidChainOrAgent
}
/**