diff --git a/packages/components/nodes/chains/MultiPromptChain/MultiPromptChain.ts b/packages/components/nodes/chains/MultiPromptChain/MultiPromptChain.ts index 659641f8c..189f41f72 100644 --- a/packages/components/nodes/chains/MultiPromptChain/MultiPromptChain.ts +++ b/packages/components/nodes/chains/MultiPromptChain/MultiPromptChain.ts @@ -49,9 +49,12 @@ class MultiPromptChain_Chains implements INode { promptTemplates.push(prompt.systemMessage) } - const chain = MultiPromptChain.fromPrompts(model, promptNames, promptDescriptions, promptTemplates, undefined, { - verbose: process.env.DEBUG === 'true' ? true : false - } as any) + const chain = MultiPromptChain.fromLLMAndPrompts(model, { + promptNames, + promptDescriptions, + promptTemplates, + llmChainOpts: { verbose: process.env.DEBUG === 'true' ? true : false } + }) return chain } diff --git a/packages/components/nodes/chains/MultiRetrievalQAChain/MultiRetrievalQAChain.ts b/packages/components/nodes/chains/MultiRetrievalQAChain/MultiRetrievalQAChain.ts index b18ac8677..b3575a930 100644 --- a/packages/components/nodes/chains/MultiRetrievalQAChain/MultiRetrievalQAChain.ts +++ b/packages/components/nodes/chains/MultiRetrievalQAChain/MultiRetrievalQAChain.ts @@ -32,6 +32,12 @@ class MultiRetrievalQAChain_Chains implements INode { name: 'vectorStoreRetriever', type: 'VectorStoreRetriever', list: true + }, + { + label: 'Return Source Documents', + name: 'returnSourceDocuments', + type: 'boolean', + optional: true } ] } @@ -39,6 +45,8 @@ class MultiRetrievalQAChain_Chains implements INode { async init(nodeData: INodeData): Promise { const model = nodeData.inputs?.model as BaseLanguageModel const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as VectorStoreRetriever[] + const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean + const retrieverNames = [] const retrieverDescriptions = [] const retrievers = [] @@ -49,23 +57,29 @@ class MultiRetrievalQAChain_Chains implements INode { retrievers.push(vs.vectorStore.asRetriever((vs.vectorStore as any).k ?? 4)) } - const chain = MultiRetrievalQAChain.fromRetrievers(model, retrieverNames, retrieverDescriptions, retrievers, undefined, { - verbose: process.env.DEBUG === 'true' ? true : false - } as any) - + const chain = MultiRetrievalQAChain.fromLLMAndRetrievers(model, { + retrieverNames, + retrieverDescriptions, + retrievers, + retrievalQAChainOpts: { verbose: process.env.DEBUG === 'true' ? true : false, returnSourceDocuments } + }) return chain } - async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { + async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const chain = nodeData.instance as MultiRetrievalQAChain + const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean + const obj = { input } if (options.socketIO && options.socketIOClientId) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, 2) + const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, 2, returnSourceDocuments) const res = await chain.call(obj, [handler]) + if (res.text && res.sourceDocuments) return res return res?.text } else { const res = await chain.call(obj) + if (res.text && res.sourceDocuments) return res return res?.text } } diff --git a/packages/server/marketplaces/Multi Retrieval QA Chain.json b/packages/server/marketplaces/Multi Retrieval QA Chain.json index cdb468d1f..f3cd1fcc1 100644 --- a/packages/server/marketplaces/Multi Retrieval QA Chain.json +++ b/packages/server/marketplaces/Multi Retrieval QA Chain.json @@ -84,7 +84,14 @@ "baseClasses": ["MultiRetrievalQAChain", "MultiRouteChain", "BaseChain", "BaseLangChain"], "category": "Chains", "description": "QA Chain that automatically picks an appropriate vector store from multiple retrievers", - "inputParams": [], + "inputParams": [ + { + "label": "Return Source Documents", + "name": "returnSourceDocuments", + "type": "boolean", + "optional": true + } + ], "inputAnchors": [ { "label": "Language Model",