From 3696c4517a27cd190d43c67da029deac09f474ce Mon Sep 17 00:00:00 2001 From: vinodkiran Date: Sat, 28 Oct 2023 09:09:29 +0530 Subject: [PATCH] Removal of the custom output parsing. --- .../nodes/chains/LLMChain/LLMChain.ts | 42 +++++++++++-------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/packages/components/nodes/chains/LLMChain/LLMChain.ts b/packages/components/nodes/chains/LLMChain/LLMChain.ts index 6d041e504..d1860f1b4 100644 --- a/packages/components/nodes/chains/LLMChain/LLMChain.ts +++ b/packages/components/nodes/chains/LLMChain/LLMChain.ts @@ -4,7 +4,8 @@ import { LLMChain } from 'langchain/chains' import { BaseLanguageModel } from 'langchain/base_language' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { BaseOutputParser } from 'langchain/schema/output_parser' -import { injectOutputParser, applyOutputParser } from '../../outputparsers/OutputParserHelpers' +import { injectOutputParser } from '../../outputparsers/OutputParserHelpers' +import { BaseLLMOutputParser } from 'langchain/schema/output_parser' class LLMChain_Chains implements INode { label: string @@ -28,16 +29,16 @@ class LLMChain_Chains implements INode { this.description = 'Chain to run queries against LLMs' this.baseClasses = [this.type, ...getBaseClasses(LLMChain)] this.inputs = [ - { - label: 'Language Model', - name: 'model', - type: 'BaseLanguageModel' - }, { label: 'Prompt', name: 'prompt', type: 'BasePromptTemplate' }, + { + label: 'Language Model', + name: 'model', + type: 'BaseLanguageModel' + }, { label: 'Output Parser', name: 'outputParser', @@ -71,12 +72,18 @@ class LLMChain_Chains implements INode { const prompt = nodeData.inputs?.prompt const output = nodeData.outputs?.output as string const promptValues = prompt.promptValues as ICommonObject + const llmOutputParser = nodeData.inputs?.outputParser as BaseLLMOutputParser if (output === this.name) { - const chain = new LLMChain({ llm: model, prompt, verbose: process.env.DEBUG === 'true' ? true : false }) + const chain = new LLMChain({ llm: model, outputParser: llmOutputParser, prompt, verbose: process.env.DEBUG === 'true' }) return chain } else if (output === 'outputPrediction') { - const chain = new LLMChain({ llm: model, prompt, verbose: process.env.DEBUG === 'true' ? true : false }) + const chain = new LLMChain({ + llm: model, + outputParser: llmOutputParser, + prompt, + verbose: process.env.DEBUG === 'true' + }) const inputVariables = chain.prompt.inputVariables as string[] // ["product"] const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData) // eslint-disable-next-line no-console @@ -98,7 +105,7 @@ class LLMChain_Chains implements INode { let promptValues: ICommonObject | undefined = nodeData.inputs?.prompt.promptValues as ICommonObject const outputParser = nodeData.inputs?.outputParser as BaseOutputParser promptValues = injectOutputParser(outputParser, chain, promptValues) - const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData, outputParser) + const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData) // eslint-disable-next-line no-console console.log('\x1b[93m\x1b[1m\n*****FINAL RESULT*****\n\x1b[0m\x1b[0m') // eslint-disable-next-line no-console @@ -109,12 +116,11 @@ class LLMChain_Chains implements INode { const runPrediction = async ( inputVariables: string[], - chain: LLMChain, + chain: LLMChain, input: string, promptValuesRaw: ICommonObject | undefined, options: ICommonObject, - nodeData: INodeData, - outputParser: BaseOutputParser | undefined = undefined + nodeData: INodeData ) => { const loggerHandler = new ConsoleCallbackHandler(options.logger) const callbacks = await additionalCallbacks(nodeData, options) @@ -146,10 +152,10 @@ const runPrediction = async ( if (isStreaming) { const handler = new CustomChainHandler(socketIO, socketIOClientId) const res = await chain.call(options, [loggerHandler, handler, ...callbacks]) - return applyOutputParser(res?.text, outputParser) + return res?.text } else { const res = await chain.call(options, [loggerHandler, ...callbacks]) - return applyOutputParser(res?.text, outputParser) + return res?.text } } else if (seen.length === 1) { // If one inputVariable is not specify, use input (user's question) as value @@ -162,10 +168,10 @@ const runPrediction = async ( if (isStreaming) { const handler = new CustomChainHandler(socketIO, socketIOClientId) const res = await chain.call(options, [loggerHandler, handler, ...callbacks]) - return applyOutputParser(res?.text, outputParser) + return res?.text } else { const res = await chain.call(options, [loggerHandler, ...callbacks]) - return applyOutputParser(res?.text, outputParser) + return res?.text } } else { throw new Error(`Please provide Prompt Values for: ${seen.join(', ')}`) @@ -174,10 +180,10 @@ const runPrediction = async ( if (isStreaming) { const handler = new CustomChainHandler(socketIO, socketIOClientId) const res = await chain.run(input, [loggerHandler, handler, ...callbacks]) - return applyOutputParser(res, outputParser) + return res } else { const res = await chain.run(input, [loggerHandler, ...callbacks]) - return applyOutputParser(res, outputParser) + return res } } }