enable faiss streaming
This commit is contained in:
parent
e9adc15ff2
commit
ab60a6bda1
|
|
@ -2,6 +2,7 @@ import { INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/I
|
||||||
import { FaissStore } from 'langchain/vectorstores/faiss'
|
import { FaissStore } from 'langchain/vectorstores/faiss'
|
||||||
import { Embeddings } from 'langchain/embeddings/base'
|
import { Embeddings } from 'langchain/embeddings/base'
|
||||||
import { getBaseClasses } from '../../../src/utils'
|
import { getBaseClasses } from '../../../src/utils'
|
||||||
|
import { Document } from 'langchain/document'
|
||||||
|
|
||||||
class Faiss_Existing_VectorStores implements INode {
|
class Faiss_Existing_VectorStores implements INode {
|
||||||
label: string
|
label: string
|
||||||
|
|
@ -70,6 +71,23 @@ class Faiss_Existing_VectorStores implements INode {
|
||||||
|
|
||||||
const vectorStore = await FaissStore.load(basePath, embeddings)
|
const vectorStore = await FaissStore.load(basePath, embeddings)
|
||||||
|
|
||||||
|
// Avoid illegal invocation error
|
||||||
|
vectorStore.similaritySearchVectorWithScore = async (query: number[], k: number) => {
|
||||||
|
const index = vectorStore.index
|
||||||
|
|
||||||
|
if (k > index.ntotal()) {
|
||||||
|
const total = index.ntotal()
|
||||||
|
console.warn(`k (${k}) is greater than the number of elements in the index (${total}), setting k to ${total}`)
|
||||||
|
k = total
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = index.search(query, k)
|
||||||
|
return result.labels.map((id, index) => {
|
||||||
|
const uuid = vectorStore._mapping[id]
|
||||||
|
return [vectorStore.docstore.search(uuid), result.distances[index]] as [Document, number]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if (output === 'retriever') {
|
if (output === 'retriever') {
|
||||||
const retriever = vectorStore.asRetriever(k)
|
const retriever = vectorStore.asRetriever(k)
|
||||||
return retriever
|
return retriever
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,23 @@ class FaissUpsert_VectorStores implements INode {
|
||||||
const vectorStore = await FaissStore.fromDocuments(finalDocs, embeddings)
|
const vectorStore = await FaissStore.fromDocuments(finalDocs, embeddings)
|
||||||
await vectorStore.save(basePath)
|
await vectorStore.save(basePath)
|
||||||
|
|
||||||
|
// Avoid illegal invocation error
|
||||||
|
vectorStore.similaritySearchVectorWithScore = async (query: number[], k: number) => {
|
||||||
|
const index = vectorStore.index
|
||||||
|
|
||||||
|
if (k > index.ntotal()) {
|
||||||
|
const total = index.ntotal()
|
||||||
|
console.warn(`k (${k}) is greater than the number of elements in the index (${total}), setting k to ${total}`)
|
||||||
|
k = total
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = index.search(query, k)
|
||||||
|
return result.labels.map((id, index) => {
|
||||||
|
const uuid = vectorStore._mapping[id]
|
||||||
|
return [vectorStore.docstore.search(uuid), result.distances[index]] as [Document, number]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if (output === 'retriever') {
|
if (output === 'retriever') {
|
||||||
const retriever = vectorStore.asRetriever(k)
|
const retriever = vectorStore.asRetriever(k)
|
||||||
return retriever
|
return retriever
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,6 @@ import {
|
||||||
isSameOverrideConfig,
|
isSameOverrideConfig,
|
||||||
replaceAllAPIKeys,
|
replaceAllAPIKeys,
|
||||||
isFlowValidForStream,
|
isFlowValidForStream,
|
||||||
isVectorStoreFaiss,
|
|
||||||
databaseEntities,
|
databaseEntities,
|
||||||
getApiKey,
|
getApiKey,
|
||||||
transformToCredentialEntity,
|
transformToCredentialEntity,
|
||||||
|
|
@ -911,7 +910,6 @@ export class App {
|
||||||
const nodeModule = await import(nodeInstanceFilePath)
|
const nodeModule = await import(nodeInstanceFilePath)
|
||||||
const nodeInstance = new nodeModule.nodeClass()
|
const nodeInstance = new nodeModule.nodeClass()
|
||||||
|
|
||||||
isStreamValid = isStreamValid && !isVectorStoreFaiss(nodeToExecuteData)
|
|
||||||
logger.debug(`[server]: Running ${nodeToExecuteData.label} (${nodeToExecuteData.id})`)
|
logger.debug(`[server]: Running ${nodeToExecuteData.label} (${nodeToExecuteData.id})`)
|
||||||
|
|
||||||
if (nodeToExecuteData.instance) checkMemorySessionId(nodeToExecuteData.instance, chatId)
|
if (nodeToExecuteData.instance) checkMemorySessionId(nodeToExecuteData.instance, chatId)
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ import {
|
||||||
IComponentCredentials,
|
IComponentCredentials,
|
||||||
ICredentialReqBody
|
ICredentialReqBody
|
||||||
} from '../Interface'
|
} from '../Interface'
|
||||||
import { cloneDeep, get, omit, merge, isEqual } from 'lodash'
|
import { cloneDeep, get, isEqual } from 'lodash'
|
||||||
import {
|
import {
|
||||||
ICommonObject,
|
ICommonObject,
|
||||||
getInputVariables,
|
getInputVariables,
|
||||||
|
|
@ -393,25 +393,6 @@ export const getVariableValue = (
|
||||||
return returnVal
|
return returnVal
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Temporarily disable streaming if vectorStore is Faiss
|
|
||||||
* @param {INodeData} flowNodeData
|
|
||||||
* @returns {boolean}
|
|
||||||
*/
|
|
||||||
export const isVectorStoreFaiss = (flowNodeData: INodeData) => {
|
|
||||||
if (flowNodeData.inputs && flowNodeData.inputs.vectorStoreRetriever) {
|
|
||||||
const vectorStoreRetriever = flowNodeData.inputs.vectorStoreRetriever
|
|
||||||
if (typeof vectorStoreRetriever === 'string' && vectorStoreRetriever.includes('faiss')) return true
|
|
||||||
if (
|
|
||||||
typeof vectorStoreRetriever === 'object' &&
|
|
||||||
vectorStoreRetriever.vectorStore &&
|
|
||||||
vectorStoreRetriever.vectorStore.constructor.name === 'FaissStore'
|
|
||||||
)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loop through each inputs and resolve variable if neccessary
|
* Loop through each inputs and resolve variable if neccessary
|
||||||
* @param {INodeData} reactFlowNodeData
|
* @param {INodeData} reactFlowNodeData
|
||||||
|
|
@ -426,11 +407,6 @@ export const resolveVariables = (
|
||||||
chatHistory: IMessage[]
|
chatHistory: IMessage[]
|
||||||
): INodeData => {
|
): INodeData => {
|
||||||
let flowNodeData = cloneDeep(reactFlowNodeData)
|
let flowNodeData = cloneDeep(reactFlowNodeData)
|
||||||
if (reactFlowNodeData.instance && isVectorStoreFaiss(reactFlowNodeData)) {
|
|
||||||
// omit and merge because cloneDeep of instance gives "Illegal invocation" Exception
|
|
||||||
const flowNodeDataWithoutInstance = cloneDeep(omit(reactFlowNodeData, ['instance']))
|
|
||||||
flowNodeData = merge(flowNodeDataWithoutInstance, { instance: reactFlowNodeData.instance })
|
|
||||||
}
|
|
||||||
const types = 'inputs'
|
const types = 'inputs'
|
||||||
|
|
||||||
const getParamValues = (paramsObj: ICommonObject) => {
|
const getParamValues = (paramsObj: ICommonObject) => {
|
||||||
|
|
@ -819,7 +795,7 @@ export const isFlowValidForStream = (reactFlowNodes: IReactFlowNode[], endingNod
|
||||||
isValidChainOrAgent = whitelistAgents.includes(endingNodeData.name)
|
isValidChainOrAgent = whitelistAgents.includes(endingNodeData.name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return isChatOrLLMsExist && isValidChainOrAgent && !isVectorStoreFaiss(endingNodeData)
|
return isChatOrLLMsExist && isValidChainOrAgent
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue