import { DynamicTool, DynamicToolInput } from '@langchain/core/tools'
import { BaseChain } from 'langchain/chains'
import { handleEscapeCharacters } from '../../../src/utils'
import { CustomChainHandler } from '../../../src'

/** Constructor options for {@link ChainTool}: every DynamicTool option except `func`, plus the chain to wrap. */
export interface ChainToolInput extends Omit<DynamicToolInput, 'func'> {
    chain: BaseChain
}

/**
 * Clears `sseStreamer` on every CustomChainHandler found in `handlers`, so that
 * events emitted by a nested chain are not streamed through them.
 *
 * @param handlers - callback handlers of the current run (may contain non-CustomChainHandler entries)
 * @returns a function that puts back each handler's own original `sseStreamer`
 */
const muteSseStreamers = (handlers: readonly unknown[]): (() => void) => {
    const saved = new Map<CustomChainHandler, unknown>()
    for (const handler of handlers) {
        if (handler instanceof CustomChainHandler) {
            saved.set(handler, (handler as any).sseStreamer)
            ;(handler as any).sseStreamer = undefined
        }
    }
    return () => {
        saved.forEach((streamer, handler) => {
            ;(handler as any).sseStreamer = streamer
        })
    }
}

/**
 * Exposes a LangChain chain as a DynamicTool.
 *
 * While the wrapped chain runs, SSE streaming on the parent's CustomChainHandlers is
 * suspended so the sub-chain's events are not sent to the client. Streaming is
 * restored afterwards, even if the chain throws.
 */
export class ChainTool extends DynamicTool {
    chain: BaseChain

    constructor({ chain, ...rest }: ChainToolInput) {
        super({
            ...rest,
            func: async (input, runManager) => {
                // prevent sending SSE events of the sub-chain
                const restoreSseStreamers = muteSseStreamers(runManager?.handlers ?? [])
                try {
                    const prompt = (chain as any).prompt
                    if (prompt?.promptValues) {
                        // The chain carries its own prompt values; the tool input is ignored on this path.
                        const promptValues = handleEscapeCharacters(prompt.promptValues, true)
                        const values = await chain.call(promptValues, runManager?.getChild())
                        return values?.text
                    }
                    // `return await` (not bare `return`) so `finally` only runs once the chain has settled.
                    return await chain.run(input, runManager?.getChild())
                } finally {
                    restoreSseStreamers()
                }
            }
        })
        this.chain = chain
    }
}