From 6159fa57ef6beab3cfd7c93d32aa9ea274a03c98 Mon Sep 17 00:00:00 2001 From: vinodkiran Date: Thu, 26 Oct 2023 10:12:51 +0530 Subject: [PATCH] Code cleanup and minor fixes. --- .../nodes/chains/LLMChain/LLMChain.ts | 57 ++++--------------- .../outputparsers/OutputParserHelpers.ts | 53 +++++++++++++++++ .../csvlist/CSVListOutputParser.ts | 3 +- .../customlist/CustomListOutputParser.ts | 3 +- .../structured/StructuredOutputParser.ts | 17 ++---- 5 files changed, 74 insertions(+), 59 deletions(-) create mode 100644 packages/components/nodes/outputparsers/OutputParserHelpers.ts diff --git a/packages/components/nodes/chains/LLMChain/LLMChain.ts b/packages/components/nodes/chains/LLMChain/LLMChain.ts index 5eca8823a..6d041e504 100644 --- a/packages/components/nodes/chains/LLMChain/LLMChain.ts +++ b/packages/components/nodes/chains/LLMChain/LLMChain.ts @@ -4,7 +4,7 @@ import { LLMChain } from 'langchain/chains' import { BaseLanguageModel } from 'langchain/base_language' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { BaseOutputParser } from 'langchain/schema/output_parser' -import { ChatPromptTemplate, FewShotPromptTemplate, PromptTemplate, SystemMessagePromptTemplate } from 'langchain/prompts' +import { injectOutputParser, applyOutputParser } from '../../outputparsers/OutputParserHelpers' class LLMChain_Chains implements INode { label: string @@ -21,7 +21,7 @@ class LLMChain_Chains implements INode { constructor() { this.label = 'LLM Chain' this.name = 'llmChain' - this.version = 2.0 + this.version = 3.0 this.type = 'LLMChain' this.icon = 'chain.svg' this.category = 'Chains' @@ -95,30 +95,9 @@ class LLMChain_Chains implements INode { async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const inputVariables = nodeData.instance.prompt.inputVariables as string[] // ["product"] const chain = nodeData.instance as LLMChain - let promptValues = nodeData.inputs?.prompt.promptValues as ICommonObject + let promptValues: ICommonObject | undefined = nodeData.inputs?.prompt.promptValues as ICommonObject const outputParser = nodeData.inputs?.outputParser as BaseOutputParser - if (outputParser && chain.prompt) { - const formatInstructions = outputParser.getFormatInstructions() - if (chain.prompt instanceof PromptTemplate) { - let pt = chain.prompt - pt.template = pt.template + '\n{format_instructions}' - chain.prompt.partialVariables = { format_instructions: formatInstructions } - } else if (chain.prompt instanceof ChatPromptTemplate) { - let pt = chain.prompt - pt.promptMessages.forEach((msg) => { - if (msg instanceof SystemMessagePromptTemplate) { - ;(msg.prompt as any).partialVariables = { format_instructions: outputParser.getFormatInstructions() } - ;(msg.prompt as any).template = ((msg.prompt as any).template + '\n{format_instructions}') as string - } - }) - } else if (chain.prompt instanceof FewShotPromptTemplate) { - chain.prompt.examplePrompt.partialVariables = { format_instructions: formatInstructions } - chain.prompt.examplePrompt.template = chain.prompt.examplePrompt.template + '\n{format_instructions}' - } - - chain.prompt.inputVariables.push('format_instructions') - promptValues = { ...promptValues, format_instructions: outputParser.getFormatInstructions() } - } + promptValues = injectOutputParser(outputParser, chain, promptValues) const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData, outputParser) // eslint-disable-next-line no-console console.log('\x1b[93m\x1b[1m\n*****FINAL RESULT*****\n\x1b[0m\x1b[0m') @@ -132,7 +111,7 @@ const runPrediction = async ( inputVariables: string[], chain: LLMChain, input: string, - promptValuesRaw: ICommonObject, + promptValuesRaw: ICommonObject | undefined, options: ICommonObject, nodeData: INodeData, outputParser: BaseOutputParser | undefined = undefined @@ -167,10 +146,10 @@ const runPrediction = async ( if (isStreaming) { const handler = new CustomChainHandler(socketIO, socketIOClientId) const res = await chain.call(options, [loggerHandler, handler, ...callbacks]) - return runOutputParser(res?.text, outputParser) + return applyOutputParser(res?.text, outputParser) } else { const res = await chain.call(options, [loggerHandler, ...callbacks]) - return runOutputParser(res?.text, outputParser) + return applyOutputParser(res?.text, outputParser) } } else if (seen.length === 1) { // If one inputVariable is not specify, use input (user's question) as value @@ -183,10 +162,10 @@ const runPrediction = async ( if (isStreaming) { const handler = new CustomChainHandler(socketIO, socketIOClientId) const res = await chain.call(options, [loggerHandler, handler, ...callbacks]) - return runOutputParser(res?.text, outputParser) + return applyOutputParser(res?.text, outputParser) } else { const res = await chain.call(options, [loggerHandler, ...callbacks]) - return runOutputParser(res?.text, outputParser) + return applyOutputParser(res?.text, outputParser) } } else { throw new Error(`Please provide Prompt Values for: ${seen.join(', ')}`) @@ -195,26 +174,12 @@ const runPrediction = async ( if (isStreaming) { const handler = new CustomChainHandler(socketIO, socketIOClientId) const res = await chain.run(input, [loggerHandler, handler, ...callbacks]) - return runOutputParser(res, outputParser) + return applyOutputParser(res, outputParser) } else { const res = await chain.run(input, [loggerHandler, ...callbacks]) - return runOutputParser(res, outputParser) + return applyOutputParser(res, outputParser) } } } -const runOutputParser = async (response: string, outputParser: BaseOutputParser | undefined): Promise => { - if (outputParser) { - const parsedResponse = await outputParser.parse(response) - // eslint-disable-next-line no-console - console.log('**** parsedResponse ****', parsedResponse) - if (typeof parsedResponse === 'object') { - return JSON.stringify(parsedResponse) - } else { - return parsedResponse as string - } - } - return response -} - module.exports = { nodeClass: LLMChain_Chains } diff --git a/packages/components/nodes/outputparsers/OutputParserHelpers.ts b/packages/components/nodes/outputparsers/OutputParserHelpers.ts new file mode 100644 index 000000000..87a59170d --- /dev/null +++ b/packages/components/nodes/outputparsers/OutputParserHelpers.ts @@ -0,0 +1,53 @@ +import { BaseOutputParser } from 'langchain/schema/output_parser' +import { LLMChain } from 'langchain/chains' +import { BaseLanguageModel } from 'langchain/base_language' +import { ICommonObject } from '../../src' +import { ChatPromptTemplate, FewShotPromptTemplate, PromptTemplate, SystemMessagePromptTemplate } from 'langchain/prompts' + +export const CATEGORY = 'Output Parser (Experimental)' + +export const applyOutputParser = async (response: string, outputParser: BaseOutputParser | undefined): Promise => { + if (outputParser) { + const parsedResponse = await outputParser.parse(response) + // eslint-disable-next-line no-console + console.log('**** parsedResponse ****', parsedResponse) + if (typeof parsedResponse === 'object') { + return JSON.stringify(parsedResponse) + } else { + return parsedResponse as string + } + } + return response +} + +export const injectOutputParser = ( + outputParser: BaseOutputParser, + chain: LLMChain, + promptValues: ICommonObject | undefined = undefined +) => { + if (outputParser && chain.prompt) { + const formatInstructions = outputParser.getFormatInstructions() + if (chain.prompt instanceof PromptTemplate) { + let pt = chain.prompt + pt.template = pt.template + '\n{format_instructions}' + chain.prompt.partialVariables = { format_instructions: formatInstructions } + } else if (chain.prompt instanceof ChatPromptTemplate) { + let pt = chain.prompt + pt.promptMessages.forEach((msg) => { + if (msg instanceof SystemMessagePromptTemplate) { + ;(msg.prompt as any).partialVariables = { format_instructions: outputParser.getFormatInstructions() } + ;(msg.prompt as any).template = ((msg.prompt as any).template + '\n{format_instructions}') as string + } + }) + } else if (chain.prompt instanceof FewShotPromptTemplate) { + chain.prompt.examplePrompt.partialVariables = { format_instructions: formatInstructions } + chain.prompt.examplePrompt.template = chain.prompt.examplePrompt.template + '\n{format_instructions}' + } + + chain.prompt.inputVariables.push('format_instructions') + if (promptValues) { + promptValues = { ...promptValues, format_instructions: outputParser.getFormatInstructions() } + } + } + return promptValues +} diff --git a/packages/components/nodes/outputparsers/csvlist/CSVListOutputParser.ts b/packages/components/nodes/outputparsers/csvlist/CSVListOutputParser.ts index 04911fb89..4bc87851a 100644 --- a/packages/components/nodes/outputparsers/csvlist/CSVListOutputParser.ts +++ b/packages/components/nodes/outputparsers/csvlist/CSVListOutputParser.ts @@ -1,6 +1,7 @@ import { getBaseClasses, ICommonObject, INode, INodeData, INodeParams } from '../../../src' import { BaseOutputParser } from 'langchain/schema/output_parser' import { CommaSeparatedListOutputParser } from 'langchain/output_parsers' +import { CATEGORY } from '../OutputParserHelpers' class CSVListOutputParser implements INode { label: string @@ -21,7 +22,7 @@ class CSVListOutputParser implements INode { this.type = 'CSVListOutputParser' this.description = 'Parse the output of an LLM call as a comma-separated list of values' this.icon = 'csv.png' - this.category = 'Output Parser' + this.category = CATEGORY this.baseClasses = [this.type, ...getBaseClasses(BaseOutputParser)] this.inputs = [] } diff --git a/packages/components/nodes/outputparsers/customlist/CustomListOutputParser.ts b/packages/components/nodes/outputparsers/customlist/CustomListOutputParser.ts index db05117ec..1e9617c3f 100644 --- a/packages/components/nodes/outputparsers/customlist/CustomListOutputParser.ts +++ b/packages/components/nodes/outputparsers/customlist/CustomListOutputParser.ts @@ -1,6 +1,7 @@ import { getBaseClasses, ICommonObject, INode, INodeData, INodeParams } from '../../../src' import { BaseOutputParser } from 'langchain/schema/output_parser' import { CustomListOutputParser as LangchainCustomListOutputParser } from 'langchain/output_parsers' +import { CATEGORY } from '../OutputParserHelpers' class CustomListOutputParser implements INode { label: string @@ -21,7 +22,7 @@ class CustomListOutputParser implements INode { this.type = 'CustomListOutputParser' this.description = 'Parse the output of an LLM call as a list of values.' this.icon = 'list.png' - this.category = 'Output Parser' + this.category = CATEGORY this.baseClasses = [this.type, ...getBaseClasses(BaseOutputParser)] this.inputs = [ { diff --git a/packages/components/nodes/outputparsers/structured/StructuredOutputParser.ts b/packages/components/nodes/outputparsers/structured/StructuredOutputParser.ts index ce10239b2..ef04de7de 100644 --- a/packages/components/nodes/outputparsers/structured/StructuredOutputParser.ts +++ b/packages/components/nodes/outputparsers/structured/StructuredOutputParser.ts @@ -1,6 +1,7 @@ import { getBaseClasses, ICommonObject, INode, INodeData, INodeParams } from '../../../src' import { BaseOutputParser } from 'langchain/schema/output_parser' import { StructuredOutputParser as LangchainStructuredOutputParser } from 'langchain/output_parsers' +import { CATEGORY } from '../OutputParserHelpers' class StructuredOutputParser implements INode { label: string @@ -21,8 +22,9 @@ class StructuredOutputParser implements INode { this.type = 'StructuredOutputParser' this.description = 'Parse the output of an LLM call into a given (JSON) structure.' this.icon = 'structure.png' - this.category = 'Output Parser' + this.category = CATEGORY this.baseClasses = [this.type, ...getBaseClasses(BaseOutputParser)] + //TODO: To extend the structureType to ZodSchema this.inputs = [ { label: 'Structure Type', @@ -32,10 +34,6 @@ class StructuredOutputParser implements INode { { label: 'Names And Descriptions', name: 'fromNamesAndDescriptions' - }, - { - label: 'Zod Schema', - name: 'fromZodSchema' } ], default: 'fromNamesAndDescriptions' @@ -59,18 +57,15 @@ class StructuredOutputParser implements INode { const structureType = nodeData.inputs?.structureType as string const structure = nodeData.inputs?.structure as string let parsedStructure: any | undefined = undefined - if (structure) { + if (structure && structureType === 'fromNamesAndDescriptions') { try { parsedStructure = JSON.parse(structure) - if (structureType === 'fromZodSchema') { - return LangchainStructuredOutputParser.fromZodSchema(parsedStructure) - } else { - return LangchainStructuredOutputParser.fromNamesAndDescriptions(parsedStructure) - } + return LangchainStructuredOutputParser.fromNamesAndDescriptions(parsedStructure) } catch (exception) { throw new Error('Invalid JSON in StructuredOutputParser: ' + exception) } } + throw new Error('Error creating OutputParser.') } }