diff --git a/packages/components/nodes/chains/LLMChain/LLMChain.ts b/packages/components/nodes/chains/LLMChain/LLMChain.ts index 0544365af..5eca8823a 100644 --- a/packages/components/nodes/chains/LLMChain/LLMChain.ts +++ b/packages/components/nodes/chains/LLMChain/LLMChain.ts @@ -4,7 +4,7 @@ import { LLMChain } from 'langchain/chains' import { BaseLanguageModel } from 'langchain/base_language' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { BaseOutputParser } from 'langchain/schema/output_parser' -import { ChatPromptTemplate, PromptTemplate, SystemMessagePromptTemplate } from 'langchain/prompts' +import { ChatPromptTemplate, FewShotPromptTemplate, PromptTemplate, SystemMessagePromptTemplate } from 'langchain/prompts' class LLMChain_Chains implements INode { label: string @@ -99,29 +99,25 @@ class LLMChain_Chains implements INode { const outputParser = nodeData.inputs?.outputParser as BaseOutputParser if (outputParser && chain.prompt) { const formatInstructions = outputParser.getFormatInstructions() - chain.prompt.inputVariables.push('format_instructions') if (chain.prompt instanceof PromptTemplate) { let pt = chain.prompt pt.template = pt.template + '\n{format_instructions}' chain.prompt.partialVariables = { format_instructions: formatInstructions } - // eslint-disable-next-line no-console - console.log('prompt :: ', chain.prompt) } else if (chain.prompt instanceof ChatPromptTemplate) { let pt = chain.prompt pt.promptMessages.forEach((msg) => { if (msg instanceof SystemMessagePromptTemplate) { ;(msg.prompt as any).partialVariables = { format_instructions: outputParser.getFormatInstructions() } ;(msg.prompt as any).template = ((msg.prompt as any).template + '\n{format_instructions}') as string - // eslint-disable-next-line no-console - console.log(msg) } }) - //pt.template = pt.template + '\n{format_instructions}' + } else if (chain.prompt instanceof FewShotPromptTemplate) { + chain.prompt.examplePrompt.partialVariables = { format_instructions: formatInstructions } + chain.prompt.examplePrompt.template = chain.prompt.examplePrompt.template + '\n{format_instructions}' } + chain.prompt.inputVariables.push('format_instructions') promptValues = { ...promptValues, format_instructions: outputParser.getFormatInstructions() } - // eslint-disable-next-line no-console - console.log('promptValues :: ', promptValues) } const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData, outputParser) // eslint-disable-next-line no-console diff --git a/packages/components/nodes/outputparsers/structured/StructuredOutputParser.ts b/packages/components/nodes/outputparsers/structured/StructuredOutputParser.ts index e935e5fba..ce10239b2 100644 --- a/packages/components/nodes/outputparsers/structured/StructuredOutputParser.ts +++ b/packages/components/nodes/outputparsers/structured/StructuredOutputParser.ts @@ -62,15 +62,15 @@ class StructuredOutputParser implements INode { if (structure) { try { parsedStructure = JSON.parse(structure) + if (structureType === 'fromZodSchema') { + return LangchainStructuredOutputParser.fromZodSchema(parsedStructure) + } else { + return LangchainStructuredOutputParser.fromNamesAndDescriptions(parsedStructure) + } } catch (exception) { throw new Error('Invalid JSON in StructuredOutputParser: ' + exception) } } - if (structureType === 'fromZodSchema') { - return LangchainStructuredOutputParser.fromZodSchema(parsedStructure) - } else { - return LangchainStructuredOutputParser.fromNamesAndDescriptions(parsedStructure) - } } }