diff --git a/packages/components/nodes/chains/MultiPromptChain/MultiPromptChain.ts b/packages/components/nodes/chains/MultiPromptChain/MultiPromptChain.ts index a7b00f463..c74e3257f 100644 --- a/packages/components/nodes/chains/MultiPromptChain/MultiPromptChain.ts +++ b/packages/components/nodes/chains/MultiPromptChain/MultiPromptChain.ts @@ -1,6 +1,6 @@ import { BaseLanguageModel } from 'langchain/base_language' -import { INode, INodeData, INodeParams, PromptRetriever } from '../../../src/Interface' -import { getBaseClasses } from '../../../src/utils' +import { ICommonObject, INode, INodeData, INodeParams, PromptRetriever } from '../../../src/Interface' +import { CustomChainHandler, getBaseClasses } from '../../../src/utils' import { MultiPromptChain } from 'langchain/chains' class MultiPromptChain_Chains implements INode { @@ -56,12 +56,18 @@ class MultiPromptChain_Chains implements INode { return chain } - async run(nodeData: INodeData, input: string): Promise { + async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const chain = nodeData.instance as MultiPromptChain + const obj = { input } - const res = await chain.call({ input }) - - return res?.text + if (options.socketIO && options.socketIOClientId) { + const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + const res = await chain.call(obj, [handler]) + return res?.text + } else { + const res = await chain.call(obj) + return res?.text + } } } diff --git a/packages/components/nodes/chains/MultiRetrievalQAChain/MultiRetrievalQAChain.ts b/packages/components/nodes/chains/MultiRetrievalQAChain/MultiRetrievalQAChain.ts index 7b8778f10..214db5092 100644 --- a/packages/components/nodes/chains/MultiRetrievalQAChain/MultiRetrievalQAChain.ts +++ b/packages/components/nodes/chains/MultiRetrievalQAChain/MultiRetrievalQAChain.ts @@ -1,6 +1,6 @@ import { BaseLanguageModel } from 'langchain/base_language' -import { INode, INodeData, INodeParams, VectorStoreRetriever } from '../../../src/Interface' -import { getBaseClasses } from '../../../src/utils' +import { ICommonObject, INode, INodeData, INodeParams, VectorStoreRetriever } from '../../../src/Interface' +import { CustomChainHandler, getBaseClasses } from '../../../src/utils' import { MultiRetrievalQAChain } from 'langchain/chains' class MultiRetrievalQAChain_Chains implements INode { @@ -56,12 +56,18 @@ class MultiRetrievalQAChain_Chains implements INode { return chain } - async run(nodeData: INodeData, input: string): Promise { + async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const chain = nodeData.instance as MultiRetrievalQAChain + const obj = { input } - const res = await chain.call({ input }) - - return res?.text + if (options.socketIO && options.socketIOClientId) { + const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + const res = await chain.call(obj, [handler]) + return res?.text + } else { + const res = await chain.call(obj) + return res?.text + } } }