add streaming handler to MultiChains

This commit is contained in:
Henry 2023-05-26 12:35:31 +01:00
parent 2968dccd83
commit a4727c11f0
2 changed files with 24 additions and 12 deletions

View File

@ -1,6 +1,6 @@
import { BaseLanguageModel } from 'langchain/base_language'
import { INode, INodeData, INodeParams, PromptRetriever } from '../../../src/Interface'
import { getBaseClasses } from '../../../src/utils'
import { ICommonObject, INode, INodeData, INodeParams, PromptRetriever } from '../../../src/Interface'
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
import { MultiPromptChain } from 'langchain/chains'
class MultiPromptChain_Chains implements INode {
@ -56,12 +56,18 @@ class MultiPromptChain_Chains implements INode {
return chain
}
async run(nodeData: INodeData, input: string): Promise<string> {
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
const chain = nodeData.instance as MultiPromptChain
const obj = { input }
const res = await chain.call({ input })
if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
const res = await chain.call(obj, [handler])
return res?.text
} else {
const res = await chain.call(obj)
return res?.text
}
}
}

View File

@ -1,6 +1,6 @@
import { BaseLanguageModel } from 'langchain/base_language'
import { INode, INodeData, INodeParams, VectorStoreRetriever } from '../../../src/Interface'
import { getBaseClasses } from '../../../src/utils'
import { ICommonObject, INode, INodeData, INodeParams, VectorStoreRetriever } from '../../../src/Interface'
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
import { MultiRetrievalQAChain } from 'langchain/chains'
class MultiRetrievalQAChain_Chains implements INode {
@ -56,12 +56,18 @@ class MultiRetrievalQAChain_Chains implements INode {
return chain
}
async run(nodeData: INodeData, input: string): Promise<string> {
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
const chain = nodeData.instance as MultiRetrievalQAChain
const obj = { input }
const res = await chain.call({ input })
if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
const res = await chain.call(obj, [handler])
return res?.text
} else {
const res = await chain.call(obj)
return res?.text
}
}
}