update multi chains

This commit is contained in:
Henry 2023-06-14 15:47:33 +01:00
parent 74939c187a
commit e50c065aca
3 changed files with 34 additions and 10 deletions

View File

@ -49,9 +49,12 @@ class MultiPromptChain_Chains implements INode {
promptTemplates.push(prompt.systemMessage) promptTemplates.push(prompt.systemMessage)
} }
const chain = MultiPromptChain.fromPrompts(model, promptNames, promptDescriptions, promptTemplates, undefined, { const chain = MultiPromptChain.fromLLMAndPrompts(model, {
verbose: process.env.DEBUG === 'true' ? true : false promptNames,
} as any) promptDescriptions,
promptTemplates,
llmChainOpts: { verbose: process.env.DEBUG === 'true' ? true : false }
})
return chain return chain
} }

View File

@ -32,6 +32,12 @@ class MultiRetrievalQAChain_Chains implements INode {
name: 'vectorStoreRetriever', name: 'vectorStoreRetriever',
type: 'VectorStoreRetriever', type: 'VectorStoreRetriever',
list: true list: true
},
{
label: 'Return Source Documents',
name: 'returnSourceDocuments',
type: 'boolean',
optional: true
} }
] ]
} }
@ -39,6 +45,8 @@ class MultiRetrievalQAChain_Chains implements INode {
async init(nodeData: INodeData): Promise<any> { async init(nodeData: INodeData): Promise<any> {
const model = nodeData.inputs?.model as BaseLanguageModel const model = nodeData.inputs?.model as BaseLanguageModel
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as VectorStoreRetriever[] const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as VectorStoreRetriever[]
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const retrieverNames = [] const retrieverNames = []
const retrieverDescriptions = [] const retrieverDescriptions = []
const retrievers = [] const retrievers = []
@ -49,23 +57,29 @@ class MultiRetrievalQAChain_Chains implements INode {
retrievers.push(vs.vectorStore.asRetriever((vs.vectorStore as any).k ?? 4)) retrievers.push(vs.vectorStore.asRetriever((vs.vectorStore as any).k ?? 4))
} }
const chain = MultiRetrievalQAChain.fromRetrievers(model, retrieverNames, retrieverDescriptions, retrievers, undefined, { const chain = MultiRetrievalQAChain.fromLLMAndRetrievers(model, {
verbose: process.env.DEBUG === 'true' ? true : false retrieverNames,
} as any) retrieverDescriptions,
retrievers,
retrievalQAChainOpts: { verbose: process.env.DEBUG === 'true' ? true : false, returnSourceDocuments }
})
return chain return chain
} }
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> { async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
const chain = nodeData.instance as MultiRetrievalQAChain const chain = nodeData.instance as MultiRetrievalQAChain
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const obj = { input } const obj = { input }
if (options.socketIO && options.socketIOClientId) { if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, 2) const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, 2, returnSourceDocuments)
const res = await chain.call(obj, [handler]) const res = await chain.call(obj, [handler])
if (res.text && res.sourceDocuments) return res
return res?.text return res?.text
} else { } else {
const res = await chain.call(obj) const res = await chain.call(obj)
if (res.text && res.sourceDocuments) return res
return res?.text return res?.text
} }
} }

View File

@ -84,7 +84,14 @@
"baseClasses": ["MultiRetrievalQAChain", "MultiRouteChain", "BaseChain", "BaseLangChain"], "baseClasses": ["MultiRetrievalQAChain", "MultiRouteChain", "BaseChain", "BaseLangChain"],
"category": "Chains", "category": "Chains",
"description": "QA Chain that automatically picks an appropriate vector store from multiple retrievers", "description": "QA Chain that automatically picks an appropriate vector store from multiple retrievers",
"inputParams": [], "inputParams": [
{
"label": "Return Source Documents",
"name": "returnSourceDocuments",
"type": "boolean",
"optional": true
}
],
"inputAnchors": [ "inputAnchors": [
{ {
"label": "Language Model", "label": "Language Model",