From 2e42dfb635436b963980fbc82f563425b5124927 Mon Sep 17 00:00:00 2001 From: Henry Heng Date: Mon, 23 Jun 2025 19:10:41 +0100 Subject: [PATCH] Bugfix/Gemini Structured Output (#4713) * fix gemini structured output * update issues templates --- .github/ISSUE_TEMPLATE/bug_report.yml | 2 +- .github/ISSUE_TEMPLATE/feature_request.yml | 2 +- .../ChatGoogleGenerativeAI.ts | 16 +- .../FlowiseChatGoogleGenerativeAI.ts | 721 +---------------- .../FlowiseChatGoogleGenerativeAIBackup.ts | 733 ++++++++++++++++++ .../multiagents/Supervisor/Supervisor.ts | 2 +- .../ExtractMetadataRetriever.ts | 18 +- .../ConditionAgent/ConditionAgent.ts | 18 +- .../nodes/sequentialagents/LLMNode/LLMNode.ts | 17 +- packages/components/src/followUpPrompts.ts | 17 +- 10 files changed, 756 insertions(+), 790 deletions(-) create mode 100644 packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAIBackup.ts diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 0bf6516bd..5d6d81c08 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,6 +1,6 @@ name: Bug Report description: File a bug report to help us improve -title: '[BUG] ' +title: '' labels: ['bug'] assignees: [] body: diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 0a6742393..4c0ba3e59 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -1,6 +1,6 @@ name: Feature Request description: Suggest a new feature or enhancement for Flowise -title: '[FEATURE] ' +title: '' labels: ['enhancement'] assignees: [] body: diff --git a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts index b42ab4077..58b78b343 100644 --- a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts +++ b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts @@ -4,8 +4,8 @@ import { BaseCache } from '@langchain/core/caches' import { ICommonObject, IMultiModalOption, INode, INodeData, INodeOptionsValue, INodeParams } from '../../../src/Interface' import { convertMultiOptionsToStringArray, getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' import { getModels, MODEL_TYPE } from '../../../src/modelLoader' -import { ChatGoogleGenerativeAI, GoogleGenerativeAIChatInput } from './FlowiseChatGoogleGenerativeAI' -import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager' +import { ChatGoogleGenerativeAI } from './FlowiseChatGoogleGenerativeAI' +import { GoogleGenerativeAIChatInput } from '@langchain/google-genai' class GoogleGenerativeAI_ChatModels implements INode { label: string @@ -43,12 +43,6 @@ class GoogleGenerativeAI_ChatModels implements INode { type: 'BaseCache', optional: true }, - { - label: 'Context Cache', - name: 'contextCache', - type: 'GoogleAICacheManager', - optional: true - }, { label: 'Model Name', name: 'modelName', @@ -204,15 +198,14 @@ class GoogleGenerativeAI_ChatModels implements INode { const harmCategory = nodeData.inputs?.harmCategory as string const harmBlockThreshold = nodeData.inputs?.harmBlockThreshold as string const cache = nodeData.inputs?.cache as BaseCache - const contextCache = nodeData.inputs?.contextCache as FlowiseGoogleAICacheManager const streaming = 
nodeData.inputs?.streaming as boolean const baseUrl = nodeData.inputs?.baseUrl as string | undefined const allowImageUploads = nodeData.inputs?.allowImageUploads as boolean - const obj: Partial = { + const obj: GoogleGenerativeAIChatInput = { apiKey: apiKey, - modelName: customModelName || modelName, + model: customModelName || modelName, streaming: streaming ?? true } @@ -248,7 +241,6 @@ class GoogleGenerativeAI_ChatModels implements INode { const model = new ChatGoogleGenerativeAI(nodeData.id, obj) model.setMultiModalOption(multiModalOption) - if (contextCache) model.setContextCache(contextCache) return model } diff --git a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI.ts b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI.ts index 8854485c9..13a3b0032 100644 --- a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI.ts +++ b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI.ts @@ -1,374 +1,5 @@ -import { BaseMessage, AIMessage, AIMessageChunk, isBaseMessage, ChatMessage, MessageContentComplex } from '@langchain/core/messages' -import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager' -import { BaseChatModel, type BaseChatModelParams } from '@langchain/core/language_models/chat_models' -import { ChatGeneration, ChatGenerationChunk, ChatResult } from '@langchain/core/outputs' -import { ToolCallChunk } from '@langchain/core/messages/tool' -import { NewTokenIndices } from '@langchain/core/callbacks/base' -import { - EnhancedGenerateContentResponse, - Content, - Part, - Tool, - GenerativeModel, - GoogleGenerativeAI as GenerativeAI -} from '@google/generative-ai' -import type { - FunctionCallPart, - FunctionResponsePart, - SafetySetting, - UsageMetadata, - FunctionDeclarationsTool as GoogleGenerativeAIFunctionDeclarationsTool, - GenerateContentRequest -} from '@google/generative-ai' -import { ICommonObject, IMultiModalOption, IVisionChatModal } from '../../../src' -import { StructuredToolInterface } from '@langchain/core/tools' -import { isStructuredTool } from '@langchain/core/utils/function_calling' -import { zodToJsonSchema } from 'zod-to-json-schema' -import { BaseLanguageModelCallOptions } from '@langchain/core/language_models/base' -import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager' - -const DEFAULT_IMAGE_MAX_TOKEN = 8192 -const DEFAULT_IMAGE_MODEL = 'gemini-1.5-flash-latest' - -interface TokenUsage { - completionTokens?: number - promptTokens?: number - totalTokens?: number -} - -interface GoogleGenerativeAIChatCallOptions extends BaseLanguageModelCallOptions { - tools?: StructuredToolInterface[] | GoogleGenerativeAIFunctionDeclarationsTool[] - /** - * Whether or not to include usage data, like token counts - * in the streamed response chunks. 
- * @default true - */ - streamUsage?: boolean -} - -export interface GoogleGenerativeAIChatInput extends BaseChatModelParams, Pick { - modelName?: string - model?: string - temperature?: number - maxOutputTokens?: number - topP?: number - topK?: number - stopSequences?: string[] - safetySettings?: SafetySetting[] - apiKey?: string - apiVersion?: string - baseUrl?: string - streaming?: boolean -} - -class LangchainChatGoogleGenerativeAI - extends BaseChatModel - implements GoogleGenerativeAIChatInput -{ - modelName = 'gemini-pro' - - temperature?: number - - maxOutputTokens?: number - - topP?: number - - topK?: number - - stopSequences: string[] = [] - - safetySettings?: SafetySetting[] - - apiKey?: string - - baseUrl?: string - - streaming = false - - streamUsage = true - - private client: GenerativeModel - - private contextCache?: FlowiseGoogleAICacheManager - - get _isMultimodalModel() { - return true - } - - constructor(fields?: GoogleGenerativeAIChatInput) { - super(fields ?? {}) - - this.modelName = fields?.model?.replace(/^models\//, '') ?? fields?.modelName?.replace(/^models\//, '') ?? 'gemini-pro' - - this.maxOutputTokens = fields?.maxOutputTokens ?? this.maxOutputTokens - - if (this.maxOutputTokens && this.maxOutputTokens < 0) { - throw new Error('`maxOutputTokens` must be a positive integer') - } - - this.temperature = fields?.temperature ?? this.temperature - if (this.temperature && (this.temperature < 0 || this.temperature > 1)) { - throw new Error('`temperature` must be in the range of [0.0,1.0]') - } - - this.topP = fields?.topP ?? this.topP - if (this.topP && this.topP < 0) { - throw new Error('`topP` must be a positive integer') - } - - if (this.topP && this.topP > 1) { - throw new Error('`topP` must be below 1.') - } - - this.topK = fields?.topK ?? this.topK - if (this.topK && this.topK < 0) { - throw new Error('`topK` must be a positive integer') - } - - this.stopSequences = fields?.stopSequences ?? this.stopSequences - - this.apiKey = fields?.apiKey ?? process.env['GOOGLE_API_KEY'] - if (!this.apiKey) { - throw new Error( - 'Please set an API key for Google GenerativeAI ' + - 'in the environment variable GOOGLE_API_KEY ' + - 'or in the `apiKey` field of the ' + - 'ChatGoogleGenerativeAI constructor' - ) - } - - this.safetySettings = fields?.safetySettings ?? this.safetySettings - if (this.safetySettings && this.safetySettings.length > 0) { - const safetySettingsSet = new Set(this.safetySettings.map((s) => s.category)) - if (safetySettingsSet.size !== this.safetySettings.length) { - throw new Error('The categories in `safetySettings` array must be unique') - } - } - - this.streaming = fields?.streaming ?? this.streaming - - this.streamUsage = fields?.streamUsage ?? this.streamUsage - - this.getClient() - } - - async getClient(prompt?: Content[], tools?: Tool[]) { - this.client = new GenerativeAI(this.apiKey ?? '').getGenerativeModel( - { - model: this.modelName, - tools, - safetySettings: this.safetySettings as SafetySetting[], - generationConfig: { - candidateCount: 1, - stopSequences: this.stopSequences, - maxOutputTokens: this.maxOutputTokens, - temperature: this.temperature, - topP: this.topP, - topK: this.topK - } - }, - { - baseUrl: this.baseUrl - } - ) - if (this.contextCache) { - const cachedContent = await this.contextCache.lookup({ - contents: prompt ? 
[{ ...prompt[0], parts: prompt[0].parts.slice(0, 1) }] : [], - model: this.modelName, - tools - }) - this.client.cachedContent = cachedContent as any - } - } - - _combineLLMOutput() { - return [] - } - - _llmType() { - return 'googlegenerativeai' - } - - override bindTools(tools: (StructuredToolInterface | Record)[], kwargs?: Partial) { - //@ts-ignore - return this.bind({ tools: convertToGeminiTools(tools), ...kwargs }) - } - - invocationParams(options?: this['ParsedCallOptions']): Omit { - const tools = options?.tools as GoogleGenerativeAIFunctionDeclarationsTool[] | StructuredToolInterface[] | undefined - if (Array.isArray(tools) && !tools.some((t: any) => !('lc_namespace' in t))) { - return { - tools: convertToGeminiTools(options?.tools as StructuredToolInterface[]) as any - } - } - return { - tools: options?.tools as GoogleGenerativeAIFunctionDeclarationsTool[] | undefined - } - } - - convertFunctionResponse(prompts: Content[]) { - for (let i = 0; i < prompts.length; i += 1) { - if (prompts[i].role === 'function') { - if (prompts[i - 1].role === 'model') { - const toolName = prompts[i - 1].parts[0].functionCall?.name ?? '' - prompts[i].parts = [ - { - functionResponse: { - name: toolName, - response: { - name: toolName, - content: prompts[i].parts[0].text - } - } - } - ] - } - } - } - } - - setContextCache(contextCache: FlowiseGoogleAICacheManager): void { - this.contextCache = contextCache - } - - async getNumTokens(prompt: BaseMessage[]) { - const contents = convertBaseMessagesToContent(prompt, this._isMultimodalModel) - const { totalTokens } = await this.client.countTokens({ contents }) - return totalTokens - } - - async _generateNonStreaming( - prompt: Content[], - options: this['ParsedCallOptions'], - _runManager?: CallbackManagerForLLMRun - ): Promise { - //@ts-ignore - const tools = options.tools ?? [] - - this.convertFunctionResponse(prompt) - - if (tools.length > 0) { - await this.getClient(prompt, tools as Tool[]) - } else { - await this.getClient(prompt) - } - const res = await this.caller.callWithOptions({ signal: options?.signal }, async () => { - let output - try { - output = await this.client.generateContent({ - contents: prompt - }) - } catch (e: any) { - if (e.message?.includes('400 Bad Request')) { - e.status = 400 - } - throw e - } - return output - }) - const generationResult = mapGenerateContentResultToChatResult(res.response) - await _runManager?.handleLLMNewToken(generationResult.generations?.length ? generationResult.generations[0].text : '') - return generationResult - } - - async _generate( - messages: BaseMessage[], - options: this['ParsedCallOptions'], - runManager?: CallbackManagerForLLMRun - ): Promise { - let prompt = convertBaseMessagesToContent(messages, this._isMultimodalModel) - prompt = checkIfEmptyContentAndSameRole(prompt) - - // Handle streaming - if (this.streaming) { - const tokenUsage: TokenUsage = {} - const stream = this._streamResponseChunks(messages, options, runManager) - const finalChunks: Record = {} - - for await (const chunk of stream) { - const index = (chunk.generationInfo as NewTokenIndices)?.completion ?? 
0 - if (finalChunks[index] === undefined) { - finalChunks[index] = chunk - } else { - finalChunks[index] = finalChunks[index].concat(chunk) - } - } - const generations = Object.entries(finalChunks) - .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10)) - .map(([_, value]) => value) - - return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } } - } - return this._generateNonStreaming(prompt, options, runManager) - } - - async *_streamResponseChunks( - messages: BaseMessage[], - options: this['ParsedCallOptions'], - runManager?: CallbackManagerForLLMRun - ): AsyncGenerator { - let prompt = convertBaseMessagesToContent(messages, this._isMultimodalModel) - prompt = checkIfEmptyContentAndSameRole(prompt) - - const parameters = this.invocationParams(options) - const request = { - ...parameters, - contents: prompt - } - - const tools = options.tools ?? [] - if (tools.length > 0) { - await this.getClient(prompt, tools as Tool[]) - } else { - await this.getClient(prompt) - } - - const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => { - const { stream } = await this.client.generateContentStream(request) - return stream - }) - - let usageMetadata: UsageMetadata | ICommonObject | undefined - let index = 0 - for await (const response of stream) { - if ('usageMetadata' in response && this.streamUsage !== false && options.streamUsage !== false) { - const genAIUsageMetadata = response.usageMetadata as { - promptTokenCount: number - candidatesTokenCount: number - totalTokenCount: number - } - if (!usageMetadata) { - usageMetadata = { - input_tokens: genAIUsageMetadata.promptTokenCount, - output_tokens: genAIUsageMetadata.candidatesTokenCount, - total_tokens: genAIUsageMetadata.totalTokenCount - } - } else { - // Under the hood, LangChain combines the prompt tokens. Google returns the updated - // total each time, so we need to find the difference between the tokens. - const outputTokenDiff = genAIUsageMetadata.candidatesTokenCount - (usageMetadata as ICommonObject).output_tokens - usageMetadata = { - input_tokens: 0, - output_tokens: outputTokenDiff, - total_tokens: outputTokenDiff - } - } - } - - const chunk = convertResponseContentToChatGenerationChunk(response, { - usageMetadata: usageMetadata as UsageMetadata, - index - }) - index += 1 - if (!chunk) { - continue - } - - yield chunk - await runManager?.handleLLMNewToken(chunk.text ?? '') - } - } -} +import { ChatGoogleGenerativeAI as LangchainChatGoogleGenerativeAI, GoogleGenerativeAIChatInput } from '@langchain/google-genai' +import { IMultiModalOption, IVisionChatModal } from '../../../src' export class ChatGoogleGenerativeAI extends LangchainChatGoogleGenerativeAI implements IVisionChatModal { configuredModel: string @@ -376,15 +7,15 @@ export class ChatGoogleGenerativeAI extends LangchainChatGoogleGenerativeAI impl multiModalOption: IMultiModalOption id: string - constructor(id: string, fields?: GoogleGenerativeAIChatInput) { + constructor(id: string, fields: GoogleGenerativeAIChatInput) { super(fields) this.id = id - this.configuredModel = fields?.modelName ?? '' + this.configuredModel = fields?.model ?? 
'' this.configuredMaxToken = fields?.maxOutputTokens } revertToOriginalModel(): void { - this.modelName = this.configuredModel + this.model = this.configuredModel this.maxOutputTokens = this.configuredMaxToken } @@ -393,346 +24,6 @@ export class ChatGoogleGenerativeAI extends LangchainChatGoogleGenerativeAI impl } setVisionModel(): void { - if (this.modelName === 'gemini-1.0-pro-latest') { - this.modelName = DEFAULT_IMAGE_MODEL - this.maxOutputTokens = this.configuredMaxToken ? this.configuredMaxToken : DEFAULT_IMAGE_MAX_TOKEN - } + // pass } } - -function messageContentMedia(content: MessageContentComplex): Part { - if ('mimeType' in content && 'data' in content) { - return { - inlineData: { - mimeType: content.mimeType, - data: content.data - } - } - } - - throw new Error('Invalid media content') -} - -function getMessageAuthor(message: BaseMessage) { - const type = message._getType() - if (ChatMessage.isInstance(message)) { - return message.role - } - return message.name ?? type -} - -function convertAuthorToRole(author: string) { - switch (author.toLowerCase()) { - case 'ai': - case 'assistant': - case 'model': - return 'model' - case 'function': - case 'tool': - return 'function' - case 'system': - case 'human': - default: - return 'user' - } -} - -function convertMessageContentToParts(message: BaseMessage, isMultimodalModel: boolean): Part[] { - if (typeof message.content === 'string' && message.content !== '') { - return [{ text: message.content }] - } - - let functionCalls: FunctionCallPart[] = [] - let functionResponses: FunctionResponsePart[] = [] - let messageParts: Part[] = [] - - if ('tool_calls' in message && Array.isArray(message.tool_calls) && message.tool_calls.length > 0) { - functionCalls = message.tool_calls.map((tc) => ({ - functionCall: { - name: tc.name, - args: tc.args - } - })) - } else if (message._getType() === 'tool' && message.name && message.content) { - functionResponses = [ - { - functionResponse: { - name: message.name, - response: message.content - } - } - ] - } else if (Array.isArray(message.content)) { - messageParts = message.content.map((c) => { - if (c.type === 'text') { - return { - text: c.text - } - } - - if (c.type === 'image_url') { - if (!isMultimodalModel) { - throw new Error(`This model does not support images`) - } - let source - if (typeof c.image_url === 'string') { - source = c.image_url - } else if (typeof c.image_url === 'object' && 'url' in c.image_url) { - source = c.image_url.url - } else { - throw new Error('Please provide image as base64 encoded data URL') - } - const [dm, data] = source.split(',') - if (!dm.startsWith('data:')) { - throw new Error('Please provide image as base64 encoded data URL') - } - - const [mimeType, encoding] = dm.replace(/^data:/, '').split(';') - if (encoding !== 'base64') { - throw new Error('Please provide image as base64 encoded data URL') - } - - return { - inlineData: { - data, - mimeType - } - } - } else if (c.type === 'media') { - return messageContentMedia(c) - } else if (c.type === 'tool_use') { - return { - functionCall: { - name: c.name, - args: c.input - } - } - } - throw new Error(`Unknown content type ${(c as { type: string }).type}`) - }) - } - - return [...messageParts, ...functionCalls, ...functionResponses] -} - -/* - * This is a dedicated logic for Multi Agent Supervisor to handle the case where the content is empty, and the role is the same - */ - -function checkIfEmptyContentAndSameRole(contents: Content[]) { - let prevRole = '' - const validContents: Content[] = [] - - for (const 
content of contents) { - // Skip only if completely empty - if (!content.parts || !content.parts.length) { - continue - } - - // Ensure role is always either 'user' or 'model' - content.role = content.role === 'model' ? 'model' : 'user' - - // Handle consecutive messages - if (content.role === prevRole && validContents.length > 0) { - // Merge with previous content if same role - validContents[validContents.length - 1].parts.push(...content.parts) - continue - } - - validContents.push(content) - prevRole = content.role - } - - return validContents -} - -function convertBaseMessagesToContent(messages: BaseMessage[], isMultimodalModel: boolean) { - return messages.reduce<{ - content: Content[] - mergeWithPreviousContent: boolean - }>( - (acc, message, index) => { - if (!isBaseMessage(message)) { - throw new Error('Unsupported message input') - } - const author = getMessageAuthor(message) - if (author === 'system' && index !== 0) { - throw new Error('System message should be the first one') - } - const role = convertAuthorToRole(author) - - const prevContent = acc.content[acc.content.length] - if (!acc.mergeWithPreviousContent && prevContent && prevContent.role === role) { - throw new Error('Google Generative AI requires alternate messages between authors') - } - - const parts = convertMessageContentToParts(message, isMultimodalModel) - - if (acc.mergeWithPreviousContent) { - const prevContent = acc.content[acc.content.length - 1] - if (!prevContent) { - throw new Error('There was a problem parsing your system message. Please try a prompt without one.') - } - prevContent.parts.push(...parts) - - return { - mergeWithPreviousContent: false, - content: acc.content - } - } - let actualRole = role - if (actualRole === 'function' || actualRole === 'tool') { - // GenerativeAI API will throw an error if the role is not "user" or "model." - actualRole = 'user' - } - const content: Content = { - role: actualRole, - parts - } - return { - mergeWithPreviousContent: author === 'system', - content: [...acc.content, content] - } - }, - { content: [], mergeWithPreviousContent: false } - ).content -} - -function mapGenerateContentResultToChatResult( - response: EnhancedGenerateContentResponse, - extra?: { - usageMetadata: UsageMetadata | undefined - } -): ChatResult { - // if rejected or error, return empty generations with reason in filters - if (!response.candidates || response.candidates.length === 0 || !response.candidates[0]) { - return { - generations: [], - llmOutput: { - filters: response.promptFeedback - } - } - } - - const functionCalls = response.functionCalls() - const [candidate] = response.candidates - const { content, ...generationInfo } = candidate - const text = content?.parts[0]?.text ?? '' - - const generation: ChatGeneration = { - text, - message: new AIMessage({ - content: text, - tool_calls: functionCalls, - additional_kwargs: { - ...generationInfo - }, - usage_metadata: extra?.usageMetadata as any - }), - generationInfo - } - - return { - generations: [generation] - } -} - -function convertResponseContentToChatGenerationChunk( - response: EnhancedGenerateContentResponse, - extra: { - usageMetadata?: UsageMetadata | undefined - index: number - } -): ChatGenerationChunk | null { - if (!response || !response.candidates || response.candidates.length === 0) { - return null - } - const functionCalls = response.functionCalls() - const [candidate] = response.candidates - const { content, ...generationInfo } = candidate - const text = content?.parts?.[0]?.text ?? 
'' - - const toolCallChunks: ToolCallChunk[] = [] - if (functionCalls) { - toolCallChunks.push( - ...functionCalls.map((fc) => ({ - ...fc, - args: JSON.stringify(fc.args), - index: extra.index - })) - ) - } - return new ChatGenerationChunk({ - text, - message: new AIMessageChunk({ - content: text, - name: !content ? undefined : content.role, - tool_call_chunks: toolCallChunks, - // Each chunk can have unique "generationInfo", and merging strategy is unclear, - // so leave blank for now. - additional_kwargs: {}, - usage_metadata: extra.usageMetadata as any - }), - generationInfo - }) -} - -function zodToGeminiParameters(zodObj: any) { - // Gemini doesn't accept either the $schema or additionalProperties - // attributes, so we need to explicitly remove them. - const jsonSchema: any = zodToJsonSchema(zodObj) - // eslint-disable-next-line unused-imports/no-unused-vars - const { $schema, additionalProperties, ...rest } = jsonSchema - - // Ensure all properties have type specified - if (rest.properties) { - Object.keys(rest.properties).forEach((key) => { - const prop = rest.properties[key] - - // Handle enum types - if (prop.enum?.length) { - rest.properties[key] = { - type: 'string', - format: 'enum', - enum: prop.enum - } - } - // Handle missing type - else if (!prop.type && !prop.oneOf && !prop.anyOf && !prop.allOf) { - // Infer type from other properties - if (prop.minimum !== undefined || prop.maximum !== undefined) { - prop.type = 'number' - } else if (prop.format === 'date-time') { - prop.type = 'string' - } else if (prop.items) { - prop.type = 'array' - } else if (prop.properties) { - prop.type = 'object' - } else { - // Default to string if type can't be inferred - prop.type = 'string' - } - } - }) - } - - return rest -} - -function convertToGeminiTools(structuredTools: (StructuredToolInterface | Record)[]) { - return [ - { - functionDeclarations: structuredTools.map((structuredTool) => { - if (isStructuredTool(structuredTool)) { - const jsonSchema = zodToGeminiParameters(structuredTool.schema) - return { - name: structuredTool.name, - description: structuredTool.description, - parameters: jsonSchema - } - } - return structuredTool - }) - } - ] -} diff --git a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAIBackup.ts b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAIBackup.ts new file mode 100644 index 000000000..9095e6ae6 --- /dev/null +++ b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAIBackup.ts @@ -0,0 +1,733 @@ +/** Disabled due to the withStructuredOutput + +import { BaseMessage, AIMessage, AIMessageChunk, isBaseMessage, ChatMessage, MessageContentComplex } from '@langchain/core/messages' +import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager' +import { BaseChatModel, type BaseChatModelParams } from '@langchain/core/language_models/chat_models' +import { ChatGeneration, ChatGenerationChunk, ChatResult } from '@langchain/core/outputs' +import { ToolCallChunk } from '@langchain/core/messages/tool' +import { NewTokenIndices } from '@langchain/core/callbacks/base' +import { + EnhancedGenerateContentResponse, + Content, + Part, + Tool, + GenerativeModel, + GoogleGenerativeAI as GenerativeAI +} from '@google/generative-ai' +import type { + FunctionCallPart, + FunctionResponsePart, + SafetySetting, + UsageMetadata, + FunctionDeclarationsTool as GoogleGenerativeAIFunctionDeclarationsTool, + GenerateContentRequest +} from '@google/generative-ai' 
+import { ICommonObject, IMultiModalOption, IVisionChatModal } from '../../../src' +import { StructuredToolInterface } from '@langchain/core/tools' +import { isStructuredTool } from '@langchain/core/utils/function_calling' +import { zodToJsonSchema } from 'zod-to-json-schema' +import { BaseLanguageModelCallOptions } from '@langchain/core/language_models/base' +import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager' + +const DEFAULT_IMAGE_MAX_TOKEN = 8192 +const DEFAULT_IMAGE_MODEL = 'gemini-1.5-flash-latest' + +interface TokenUsage { + completionTokens?: number + promptTokens?: number + totalTokens?: number +} + +interface GoogleGenerativeAIChatCallOptions extends BaseLanguageModelCallOptions { + tools?: StructuredToolInterface[] | GoogleGenerativeAIFunctionDeclarationsTool[] + streamUsage?: boolean +} + +export interface GoogleGenerativeAIChatInput extends BaseChatModelParams, Pick { + modelName?: string + model?: string + temperature?: number + maxOutputTokens?: number + topP?: number + topK?: number + stopSequences?: string[] + safetySettings?: SafetySetting[] + apiKey?: string + apiVersion?: string + baseUrl?: string + streaming?: boolean +} + +class LangchainChatGoogleGenerativeAI + extends BaseChatModel + implements GoogleGenerativeAIChatInput +{ + modelName = 'gemini-pro' + + temperature?: number + + maxOutputTokens?: number + + topP?: number + + topK?: number + + stopSequences: string[] = [] + + safetySettings?: SafetySetting[] + + apiKey?: string + + baseUrl?: string + + streaming = false + + streamUsage = true + + private client: GenerativeModel + + private contextCache?: FlowiseGoogleAICacheManager + + get _isMultimodalModel() { + return true + } + + constructor(fields?: GoogleGenerativeAIChatInput) { + super(fields ?? {}) + + this.modelName = fields?.model?.replace(/^models\//, '') ?? fields?.modelName?.replace(/^models\//, '') ?? 'gemini-pro' + + this.maxOutputTokens = fields?.maxOutputTokens ?? this.maxOutputTokens + + if (this.maxOutputTokens && this.maxOutputTokens < 0) { + throw new Error('`maxOutputTokens` must be a positive integer') + } + + this.temperature = fields?.temperature ?? this.temperature + if (this.temperature && (this.temperature < 0 || this.temperature > 1)) { + throw new Error('`temperature` must be in the range of [0.0,1.0]') + } + + this.topP = fields?.topP ?? this.topP + if (this.topP && this.topP < 0) { + throw new Error('`topP` must be a positive integer') + } + + if (this.topP && this.topP > 1) { + throw new Error('`topP` must be below 1.') + } + + this.topK = fields?.topK ?? this.topK + if (this.topK && this.topK < 0) { + throw new Error('`topK` must be a positive integer') + } + + this.stopSequences = fields?.stopSequences ?? this.stopSequences + + this.apiKey = fields?.apiKey ?? process.env['GOOGLE_API_KEY'] + if (!this.apiKey) { + throw new Error( + 'Please set an API key for Google GenerativeAI ' + + 'in the environment variable GOOGLE_API_KEY ' + + 'or in the `apiKey` field of the ' + + 'ChatGoogleGenerativeAI constructor' + ) + } + + this.safetySettings = fields?.safetySettings ?? this.safetySettings + if (this.safetySettings && this.safetySettings.length > 0) { + const safetySettingsSet = new Set(this.safetySettings.map((s) => s.category)) + if (safetySettingsSet.size !== this.safetySettings.length) { + throw new Error('The categories in `safetySettings` array must be unique') + } + } + + this.streaming = fields?.streaming ?? 
this.streaming + + this.streamUsage = fields?.streamUsage ?? this.streamUsage + + this.getClient() + } + + async getClient(prompt?: Content[], tools?: Tool[]) { + this.client = new GenerativeAI(this.apiKey ?? '').getGenerativeModel( + { + model: this.modelName, + tools, + safetySettings: this.safetySettings as SafetySetting[], + generationConfig: { + candidateCount: 1, + stopSequences: this.stopSequences, + maxOutputTokens: this.maxOutputTokens, + temperature: this.temperature, + topP: this.topP, + topK: this.topK + } + }, + { + baseUrl: this.baseUrl + } + ) + if (this.contextCache) { + const cachedContent = await this.contextCache.lookup({ + contents: prompt ? [{ ...prompt[0], parts: prompt[0].parts.slice(0, 1) }] : [], + model: this.modelName, + tools + }) + this.client.cachedContent = cachedContent as any + } + } + + _combineLLMOutput() { + return [] + } + + _llmType() { + return 'googlegenerativeai' + } + + override bindTools(tools: (StructuredToolInterface | Record)[], kwargs?: Partial) { + //@ts-ignore + return this.bind({ tools: convertToGeminiTools(tools), ...kwargs }) + } + + invocationParams(options?: this['ParsedCallOptions']): Omit { + const tools = options?.tools as GoogleGenerativeAIFunctionDeclarationsTool[] | StructuredToolInterface[] | undefined + if (Array.isArray(tools) && !tools.some((t: any) => !('lc_namespace' in t))) { + return { + tools: convertToGeminiTools(options?.tools as StructuredToolInterface[]) as any + } + } + return { + tools: options?.tools as GoogleGenerativeAIFunctionDeclarationsTool[] | undefined + } + } + + convertFunctionResponse(prompts: Content[]) { + for (let i = 0; i < prompts.length; i += 1) { + if (prompts[i].role === 'function') { + if (prompts[i - 1].role === 'model') { + const toolName = prompts[i - 1].parts[0].functionCall?.name ?? '' + prompts[i].parts = [ + { + functionResponse: { + name: toolName, + response: { + name: toolName, + content: prompts[i].parts[0].text + } + } + } + ] + } + } + } + } + + setContextCache(contextCache: FlowiseGoogleAICacheManager): void { + this.contextCache = contextCache + } + + async getNumTokens(prompt: BaseMessage[]) { + const contents = convertBaseMessagesToContent(prompt, this._isMultimodalModel) + const { totalTokens } = await this.client.countTokens({ contents }) + return totalTokens + } + + async _generateNonStreaming( + prompt: Content[], + options: this['ParsedCallOptions'], + _runManager?: CallbackManagerForLLMRun + ): Promise { + //@ts-ignore + const tools = options.tools ?? [] + + this.convertFunctionResponse(prompt) + + if (tools.length > 0) { + await this.getClient(prompt, tools as Tool[]) + } else { + await this.getClient(prompt) + } + const res = await this.caller.callWithOptions({ signal: options?.signal }, async () => { + let output + try { + output = await this.client.generateContent({ + contents: prompt + }) + } catch (e: any) { + if (e.message?.includes('400 Bad Request')) { + e.status = 400 + } + throw e + } + return output + }) + const generationResult = mapGenerateContentResultToChatResult(res.response) + await _runManager?.handleLLMNewToken(generationResult.generations?.length ? 
generationResult.generations[0].text : '') + return generationResult + } + + async _generate( + messages: BaseMessage[], + options: this['ParsedCallOptions'], + runManager?: CallbackManagerForLLMRun + ): Promise { + let prompt = convertBaseMessagesToContent(messages, this._isMultimodalModel) + prompt = checkIfEmptyContentAndSameRole(prompt) + + // Handle streaming + if (this.streaming) { + const tokenUsage: TokenUsage = {} + const stream = this._streamResponseChunks(messages, options, runManager) + const finalChunks: Record = {} + + for await (const chunk of stream) { + const index = (chunk.generationInfo as NewTokenIndices)?.completion ?? 0 + if (finalChunks[index] === undefined) { + finalChunks[index] = chunk + } else { + finalChunks[index] = finalChunks[index].concat(chunk) + } + } + const generations = Object.entries(finalChunks) + .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10)) + .map(([_, value]) => value) + + return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } } + } + return this._generateNonStreaming(prompt, options, runManager) + } + + async *_streamResponseChunks( + messages: BaseMessage[], + options: this['ParsedCallOptions'], + runManager?: CallbackManagerForLLMRun + ): AsyncGenerator { + let prompt = convertBaseMessagesToContent(messages, this._isMultimodalModel) + prompt = checkIfEmptyContentAndSameRole(prompt) + + const parameters = this.invocationParams(options) + const request = { + ...parameters, + contents: prompt + } + + const tools = options.tools ?? [] + if (tools.length > 0) { + await this.getClient(prompt, tools as Tool[]) + } else { + await this.getClient(prompt) + } + + const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => { + const { stream } = await this.client.generateContentStream(request) + return stream + }) + + let usageMetadata: UsageMetadata | ICommonObject | undefined + let index = 0 + for await (const response of stream) { + if ('usageMetadata' in response && this.streamUsage !== false && options.streamUsage !== false) { + const genAIUsageMetadata = response.usageMetadata as { + promptTokenCount: number + candidatesTokenCount: number + totalTokenCount: number + } + if (!usageMetadata) { + usageMetadata = { + input_tokens: genAIUsageMetadata.promptTokenCount, + output_tokens: genAIUsageMetadata.candidatesTokenCount, + total_tokens: genAIUsageMetadata.totalTokenCount + } + } else { + // Under the hood, LangChain combines the prompt tokens. Google returns the updated + // total each time, so we need to find the difference between the tokens. + const outputTokenDiff = genAIUsageMetadata.candidatesTokenCount - (usageMetadata as ICommonObject).output_tokens + usageMetadata = { + input_tokens: 0, + output_tokens: outputTokenDiff, + total_tokens: outputTokenDiff + } + } + } + + const chunk = convertResponseContentToChatGenerationChunk(response, { + usageMetadata: usageMetadata as UsageMetadata, + index + }) + index += 1 + if (!chunk) { + continue + } + + yield chunk + await runManager?.handleLLMNewToken(chunk.text ?? '') + } + } +} + +export class ChatGoogleGenerativeAI extends LangchainChatGoogleGenerativeAI implements IVisionChatModal { + configuredModel: string + configuredMaxToken?: number + multiModalOption: IMultiModalOption + id: string + + constructor(id: string, fields?: GoogleGenerativeAIChatInput) { + super(fields) + this.id = id + this.configuredModel = fields?.modelName ?? 
'' + this.configuredMaxToken = fields?.maxOutputTokens + } + + revertToOriginalModel(): void { + this.modelName = this.configuredModel + this.maxOutputTokens = this.configuredMaxToken + } + + setMultiModalOption(multiModalOption: IMultiModalOption): void { + this.multiModalOption = multiModalOption + } + + setVisionModel(): void { + if (this.modelName === 'gemini-1.0-pro-latest') { + this.modelName = DEFAULT_IMAGE_MODEL + this.maxOutputTokens = this.configuredMaxToken ? this.configuredMaxToken : DEFAULT_IMAGE_MAX_TOKEN + } + } +} + +function messageContentMedia(content: MessageContentComplex): Part { + if ('mimeType' in content && 'data' in content) { + return { + inlineData: { + mimeType: content.mimeType, + data: content.data + } + } + } + + throw new Error('Invalid media content') +} + +function getMessageAuthor(message: BaseMessage) { + const type = message._getType() + if (ChatMessage.isInstance(message)) { + return message.role + } + return message.name ?? type +} + +function convertAuthorToRole(author: string) { + switch (author.toLowerCase()) { + case 'ai': + case 'assistant': + case 'model': + return 'model' + case 'function': + case 'tool': + return 'function' + case 'system': + case 'human': + default: + return 'user' + } +} + +function convertMessageContentToParts(message: BaseMessage, isMultimodalModel: boolean): Part[] { + if (typeof message.content === 'string' && message.content !== '') { + return [{ text: message.content }] + } + + let functionCalls: FunctionCallPart[] = [] + let functionResponses: FunctionResponsePart[] = [] + let messageParts: Part[] = [] + + if ('tool_calls' in message && Array.isArray(message.tool_calls) && message.tool_calls.length > 0) { + functionCalls = message.tool_calls.map((tc) => ({ + functionCall: { + name: tc.name, + args: tc.args + } + })) + } else if (message._getType() === 'tool' && message.name && message.content) { + functionResponses = [ + { + functionResponse: { + name: message.name, + response: message.content + } + } + ] + } else if (Array.isArray(message.content)) { + messageParts = message.content.map((c) => { + if (c.type === 'text') { + return { + text: c.text + } + } + + if (c.type === 'image_url') { + if (!isMultimodalModel) { + throw new Error(`This model does not support images`) + } + let source + if (typeof c.image_url === 'string') { + source = c.image_url + } else if (typeof c.image_url === 'object' && 'url' in c.image_url) { + source = c.image_url.url + } else { + throw new Error('Please provide image as base64 encoded data URL') + } + const [dm, data] = source.split(',') + if (!dm.startsWith('data:')) { + throw new Error('Please provide image as base64 encoded data URL') + } + + const [mimeType, encoding] = dm.replace(/^data:/, '').split(';') + if (encoding !== 'base64') { + throw new Error('Please provide image as base64 encoded data URL') + } + + return { + inlineData: { + data, + mimeType + } + } + } else if (c.type === 'media') { + return messageContentMedia(c) + } else if (c.type === 'tool_use') { + return { + functionCall: { + name: c.name, + args: c.input + } + } + } + throw new Error(`Unknown content type ${(c as { type: string }).type}`) + }) + } + + return [...messageParts, ...functionCalls, ...functionResponses] +} + +// This is a dedicated logic for Multi Agent Supervisor to handle the case where the content is empty, and the role is the same +function checkIfEmptyContentAndSameRole(contents: Content[]) { + let prevRole = '' + const validContents: Content[] = [] + + for (const content of contents) { + // 
Skip only if completely empty + if (!content.parts || !content.parts.length) { + continue + } + + // Ensure role is always either 'user' or 'model' + content.role = content.role === 'model' ? 'model' : 'user' + + // Handle consecutive messages + if (content.role === prevRole && validContents.length > 0) { + // Merge with previous content if same role + validContents[validContents.length - 1].parts.push(...content.parts) + continue + } + + validContents.push(content) + prevRole = content.role + } + + return validContents +} + +function convertBaseMessagesToContent(messages: BaseMessage[], isMultimodalModel: boolean) { + return messages.reduce<{ + content: Content[] + mergeWithPreviousContent: boolean + }>( + (acc, message, index) => { + if (!isBaseMessage(message)) { + throw new Error('Unsupported message input') + } + const author = getMessageAuthor(message) + if (author === 'system' && index !== 0) { + throw new Error('System message should be the first one') + } + const role = convertAuthorToRole(author) + + const prevContent = acc.content[acc.content.length] + if (!acc.mergeWithPreviousContent && prevContent && prevContent.role === role) { + throw new Error('Google Generative AI requires alternate messages between authors') + } + + const parts = convertMessageContentToParts(message, isMultimodalModel) + + if (acc.mergeWithPreviousContent) { + const prevContent = acc.content[acc.content.length - 1] + if (!prevContent) { + throw new Error('There was a problem parsing your system message. Please try a prompt without one.') + } + prevContent.parts.push(...parts) + + return { + mergeWithPreviousContent: false, + content: acc.content + } + } + let actualRole = role + if (actualRole === 'function' || actualRole === 'tool') { + // GenerativeAI API will throw an error if the role is not "user" or "model." + actualRole = 'user' + } + const content: Content = { + role: actualRole, + parts + } + return { + mergeWithPreviousContent: author === 'system', + content: [...acc.content, content] + } + }, + { content: [], mergeWithPreviousContent: false } + ).content +} + +function mapGenerateContentResultToChatResult( + response: EnhancedGenerateContentResponse, + extra?: { + usageMetadata: UsageMetadata | undefined + } +): ChatResult { + // if rejected or error, return empty generations with reason in filters + if (!response.candidates || response.candidates.length === 0 || !response.candidates[0]) { + return { + generations: [], + llmOutput: { + filters: response.promptFeedback + } + } + } + + const functionCalls = response.functionCalls() + const [candidate] = response.candidates + const { content, ...generationInfo } = candidate + const text = content?.parts[0]?.text ?? '' + + const generation: ChatGeneration = { + text, + message: new AIMessage({ + content: text, + tool_calls: functionCalls, + additional_kwargs: { + ...generationInfo + }, + usage_metadata: extra?.usageMetadata as any + }), + generationInfo + } + + return { + generations: [generation] + } +} + +function convertResponseContentToChatGenerationChunk( + response: EnhancedGenerateContentResponse, + extra: { + usageMetadata?: UsageMetadata | undefined + index: number + } +): ChatGenerationChunk | null { + if (!response || !response.candidates || response.candidates.length === 0) { + return null + } + const functionCalls = response.functionCalls() + const [candidate] = response.candidates + const { content, ...generationInfo } = candidate + const text = content?.parts?.[0]?.text ?? 
'' + + const toolCallChunks: ToolCallChunk[] = [] + if (functionCalls) { + toolCallChunks.push( + ...functionCalls.map((fc) => ({ + ...fc, + args: JSON.stringify(fc.args), + index: extra.index + })) + ) + } + return new ChatGenerationChunk({ + text, + message: new AIMessageChunk({ + content: text, + name: !content ? undefined : content.role, + tool_call_chunks: toolCallChunks, + // Each chunk can have unique "generationInfo", and merging strategy is unclear, + // so leave blank for now. + additional_kwargs: {}, + usage_metadata: extra.usageMetadata as any + }), + generationInfo + }) +} + +function zodToGeminiParameters(zodObj: any) { + // Gemini doesn't accept either the $schema or additionalProperties + // attributes, so we need to explicitly remove them. + const jsonSchema: any = zodToJsonSchema(zodObj) + // eslint-disable-next-line unused-imports/no-unused-vars + const { $schema, additionalProperties, ...rest } = jsonSchema + + // Ensure all properties have type specified + if (rest.properties) { + Object.keys(rest.properties).forEach((key) => { + const prop = rest.properties[key] + + // Handle enum types + if (prop.enum?.length) { + rest.properties[key] = { + type: 'string', + format: 'enum', + enum: prop.enum + } + } + // Handle missing type + else if (!prop.type && !prop.oneOf && !prop.anyOf && !prop.allOf) { + // Infer type from other properties + if (prop.minimum !== undefined || prop.maximum !== undefined) { + prop.type = 'number' + } else if (prop.format === 'date-time') { + prop.type = 'string' + } else if (prop.items) { + prop.type = 'array' + } else if (prop.properties) { + prop.type = 'object' + } else { + // Default to string if type can't be inferred + prop.type = 'string' + } + } + }) + } + + return rest +} + +function convertToGeminiTools(structuredTools: (StructuredToolInterface | Record)[]) { + return [ + { + functionDeclarations: structuredTools.map((structuredTool) => { + if (isStructuredTool(structuredTool)) { + const jsonSchema = zodToGeminiParameters(structuredTool.schema) + return { + name: structuredTool.name, + description: structuredTool.description, + parameters: jsonSchema + } + } + return structuredTool + }) + } + ] +} +*/ diff --git a/packages/components/nodes/multiagents/Supervisor/Supervisor.ts b/packages/components/nodes/multiagents/Supervisor/Supervisor.ts index 1cd78eae0..f67abf00b 100644 --- a/packages/components/nodes/multiagents/Supervisor/Supervisor.ts +++ b/packages/components/nodes/multiagents/Supervisor/Supervisor.ts @@ -19,8 +19,8 @@ import { AgentExecutor, JsonOutputToolsParser, ToolCallingAgentOutputParser } fr import { ChatMistralAI } from '@langchain/mistralai' import { ChatOpenAI } from '../../chatmodels/ChatOpenAI/FlowiseChatOpenAI' import { ChatAnthropic } from '../../chatmodels/ChatAnthropic/FlowiseChatAnthropic' -import { ChatGoogleGenerativeAI } from '../../chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI' import { addImagesToMessages, llmSupportsVision } from '../../../src/multiModalUtils' +import { ChatGoogleGenerativeAI } from '../../chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI' const sysPrompt = `You are a supervisor tasked with managing a conversation between the following workers: {team_members}. Given the following user request, respond with the worker to act next. 
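
The three node changes that follow (ExtractMetadataRetriever, ConditionAgent, LLMNode) all remove the Gemini-specific `ExtractTool` + `bind({ tools })` branch and call `withStructuredOutput` on the model directly, relying on the upstream `@langchain/google-genai` implementation. Below is a minimal sketch of the resulting path; it is not code from the patch: the schema, model name, and prompt are hypothetical, and it assumes `GOOGLE_API_KEY` is set in the environment.

```typescript
import { z } from 'zod'
import { ChatGoogleGenerativeAI } from '@langchain/google-genai'

// Hypothetical schema standing in for a node's structured-output configuration.
const structuredOutput = z.object({
    keywords: z.array(z.string()).describe('Keywords extracted from the query')
})

async function extractKeywords(query: string) {
    const llm = new ChatGoogleGenerativeAI({
        apiKey: process.env.GOOGLE_API_KEY,
        model: 'gemini-1.5-flash'
    })

    // The upstream class converts the zod schema into a Gemini-compatible tool/response
    // schema internally, so the Flowise-side ExtractTool workaround is no longer needed.
    const structuredLLM = llm.withStructuredOutput(structuredOutput)

    // The result is already parsed into a plain object matching the schema.
    return structuredLLM.invoke(`Extract keywords from the query: ${query}`)
}
```
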
diff --git a/packages/components/nodes/retrievers/ExtractMetadataRetriever/ExtractMetadataRetriever.ts b/packages/components/nodes/retrievers/ExtractMetadataRetriever/ExtractMetadataRetriever.ts index 0df1b725d..1e6205f62 100644 --- a/packages/components/nodes/retrievers/ExtractMetadataRetriever/ExtractMetadataRetriever.ts +++ b/packages/components/nodes/retrievers/ExtractMetadataRetriever/ExtractMetadataRetriever.ts @@ -3,8 +3,7 @@ import { VectorStore, VectorStoreRetriever, VectorStoreRetrieverInput } from '@l import { INode, INodeData, INodeParams, INodeOutputsValue } from '../../../src/Interface' import { handleEscapeCharacters } from '../../../src' import { z } from 'zod' -import { convertStructuredSchemaToZod, ExtractTool } from '../../sequentialagents/commonUtils' -import { ChatGoogleGenerativeAI } from '@langchain/google-genai' +import { convertStructuredSchemaToZod } from '../../sequentialagents/commonUtils' const queryPrefix = 'query' const defaultPrompt = `Extract keywords from the query: {{${queryPrefix}}}` @@ -126,19 +125,8 @@ class ExtractMetadataRetriever_Retrievers implements INode { try { const structuredOutput = z.object(convertStructuredSchemaToZod(llmStructuredOutput)) - if (llm instanceof ChatGoogleGenerativeAI) { - const tool = new ExtractTool({ - schema: structuredOutput - }) - // @ts-ignore - const modelWithTool = llm.bind({ - tools: [tool] - }) as any - llm = modelWithTool - } else { - // @ts-ignore - llm = llm.withStructuredOutput(structuredOutput) - } + // @ts-ignore + llm = llm.withStructuredOutput(structuredOutput) } catch (exception) { console.error(exception) } diff --git a/packages/components/nodes/sequentialagents/ConditionAgent/ConditionAgent.ts b/packages/components/nodes/sequentialagents/ConditionAgent/ConditionAgent.ts index 3c1411c4f..1b0db13a2 100644 --- a/packages/components/nodes/sequentialagents/ConditionAgent/ConditionAgent.ts +++ b/packages/components/nodes/sequentialagents/ConditionAgent/ConditionAgent.ts @@ -18,7 +18,6 @@ import { } from '../../../src/Interface' import { getInputVariables, getVars, handleEscapeCharacters, prepareSandboxVars, transformBracesWithColon } from '../../../src/utils' import { - ExtractTool, checkCondition, convertStructuredSchemaToZod, customGet, @@ -27,7 +26,6 @@ import { filterConversationHistory, restructureMessages } from '../commonUtils' -import { ChatGoogleGenerativeAI } from '../../chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI' interface IConditionGridItem { variable: string @@ -485,20 +483,8 @@ const runCondition = async ( try { const structuredOutput = z.object(convertStructuredSchemaToZod(conditionAgentStructuredOutput)) - if (llm instanceof ChatGoogleGenerativeAI) { - const tool = new ExtractTool({ - schema: structuredOutput - }) - // @ts-ignore - const modelWithTool = llm.bind({ - tools: [tool], - signal: abortControllerSignal ? 
abortControllerSignal.signal : undefined - }) - model = modelWithTool - } else { - // @ts-ignore - model = llm.withStructuredOutput(structuredOutput) - } + // @ts-ignore + model = llm.withStructuredOutput(structuredOutput) } catch (exception) { console.error('Invalid JSON in Condition Agent Structured Output: ' + exception) model = llm diff --git a/packages/components/nodes/sequentialagents/LLMNode/LLMNode.ts b/packages/components/nodes/sequentialagents/LLMNode/LLMNode.ts index f2c31ac90..3a4edb0be 100644 --- a/packages/components/nodes/sequentialagents/LLMNode/LLMNode.ts +++ b/packages/components/nodes/sequentialagents/LLMNode/LLMNode.ts @@ -27,7 +27,6 @@ import { transformBracesWithColon } from '../../../src/utils' import { - ExtractTool, convertStructuredSchemaToZod, customGet, getVM, @@ -37,7 +36,6 @@ import { restructureMessages, checkMessageHistory } from '../commonUtils' -import { ChatGoogleGenerativeAI } from '../../chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI' const TAB_IDENTIFIER = 'selectedUpdateStateMemoryTab' const customOutputFuncDesc = `This is only applicable when you have a custom State at the START node. After agent execution, you might want to update the State values` @@ -513,19 +511,8 @@ async function createAgent( try { const structuredOutput = z.object(convertStructuredSchemaToZod(llmStructuredOutput)) - if (llm instanceof ChatGoogleGenerativeAI) { - const tool = new ExtractTool({ - schema: structuredOutput - }) - // @ts-ignore - const modelWithTool = llm.bind({ - tools: [tool] - }) as any - llm = modelWithTool - } else { - // @ts-ignore - llm = llm.withStructuredOutput(structuredOutput) - } + // @ts-ignore + llm = llm.withStructuredOutput(structuredOutput) } catch (exception) { console.error(exception) } diff --git a/packages/components/src/followUpPrompts.ts b/packages/components/src/followUpPrompts.ts index cc19864f6..4e8a732ba 100644 --- a/packages/components/src/followUpPrompts.ts +++ b/packages/components/src/followUpPrompts.ts @@ -71,24 +71,13 @@ export const generateFollowUpPrompts = async ( return structuredResponse } case FollowUpPromptProvider.GOOGLE_GENAI: { - const llm = new ChatGoogleGenerativeAI({ + const model = new ChatGoogleGenerativeAI({ apiKey: credentialData.googleGenerativeAPIKey, model: providerConfig.modelName, temperature: parseFloat(`${providerConfig.temperature}`) }) - // use structured output parser because withStructuredOutput is not working - const parser = StructuredOutputParser.fromZodSchema(FollowUpPromptType) - const formatInstructions = parser.getFormatInstructions() - const prompt = PromptTemplate.fromTemplate(` - ${providerConfig.prompt} - - {format_instructions} - `) - const chain = prompt.pipe(llm).pipe(parser) - const structuredResponse = await chain.invoke({ - history: apiMessageContent, - format_instructions: formatInstructions - }) + const structuredLLM = model.withStructuredOutput(FollowUpPromptType) + const structuredResponse = await structuredLLM.invoke(followUpPromptsPrompt) return structuredResponse } case FollowUpPromptProvider.MISTRALAI: {
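
For the follow-up prompts change above, the `StructuredOutputParser` + `PromptTemplate` chain is replaced by a single `withStructuredOutput` call on the upstream Gemini chat model. The sketch below illustrates the new flow; the schema is an illustrative stand-in for `FollowUpPromptType` (the real schema and prompt text are defined elsewhere in the Flowise codebase), and the model name is hypothetical.

```typescript
import { z } from 'zod'
import { ChatGoogleGenerativeAI } from '@langchain/google-genai'

// Illustrative stand-in for FollowUpPromptType; the real schema lives in the Flowise codebase.
const FollowUpPromptType = z.object({
    questions: z.array(z.string()).describe('Follow-up questions the user might ask next')
})

export async function generateGeminiFollowUps(apiKey: string, modelName: string, history: string) {
    const model = new ChatGoogleGenerativeAI({
        apiKey,
        model: modelName,
        temperature: 0.9
    })

    // withStructuredOutput returns a runnable whose invoke() result is already parsed
    // against the zod schema, so no output parser or format instructions are required.
    const structuredLLM = model.withStructuredOutput(FollowUpPromptType)
    return structuredLLM.invoke(`Suggest follow-up questions for this conversation:\n\n${history}`)
}
```
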