Bugfix/Gemini Structured Output (#4713)

* fix gemini structured output
* update issues templates

This commit is contained in:
parent f50a817bf4
commit 2e42dfb635
@@ -1,6 +1,6 @@
 name: Bug Report
 description: File a bug report to help us improve
-title: '[BUG] '
+title: ''
 labels: ['bug']
 assignees: []
 body:
@@ -1,6 +1,6 @@
 name: Feature Request
 description: Suggest a new feature or enhancement for Flowise
-title: '[FEATURE] '
+title: ''
 labels: ['enhancement']
 assignees: []
 body:
@@ -4,8 +4,8 @@ import { BaseCache } from '@langchain/core/caches'
 import { ICommonObject, IMultiModalOption, INode, INodeData, INodeOptionsValue, INodeParams } from '../../../src/Interface'
 import { convertMultiOptionsToStringArray, getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
 import { getModels, MODEL_TYPE } from '../../../src/modelLoader'
-import { ChatGoogleGenerativeAI, GoogleGenerativeAIChatInput } from './FlowiseChatGoogleGenerativeAI'
-import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager'
+import { ChatGoogleGenerativeAI } from './FlowiseChatGoogleGenerativeAI'
+import { GoogleGenerativeAIChatInput } from '@langchain/google-genai'
 
 class GoogleGenerativeAI_ChatModels implements INode {
     label: string
@@ -43,12 +43,6 @@ class GoogleGenerativeAI_ChatModels implements INode {
                 type: 'BaseCache',
                 optional: true
             },
-            {
-                label: 'Context Cache',
-                name: 'contextCache',
-                type: 'GoogleAICacheManager',
-                optional: true
-            },
             {
                 label: 'Model Name',
                 name: 'modelName',
@@ -204,15 +198,14 @@ class GoogleGenerativeAI_ChatModels implements INode {
         const harmCategory = nodeData.inputs?.harmCategory as string
         const harmBlockThreshold = nodeData.inputs?.harmBlockThreshold as string
         const cache = nodeData.inputs?.cache as BaseCache
-        const contextCache = nodeData.inputs?.contextCache as FlowiseGoogleAICacheManager
         const streaming = nodeData.inputs?.streaming as boolean
         const baseUrl = nodeData.inputs?.baseUrl as string | undefined
 
         const allowImageUploads = nodeData.inputs?.allowImageUploads as boolean
 
-        const obj: Partial<GoogleGenerativeAIChatInput> = {
+        const obj: GoogleGenerativeAIChatInput = {
             apiKey: apiKey,
-            modelName: customModelName || modelName,
+            model: customModelName || modelName,
             streaming: streaming ?? true
         }
 
@@ -248,7 +241,6 @@ class GoogleGenerativeAI_ChatModels implements INode {
 
         const model = new ChatGoogleGenerativeAI(nodeData.id, obj)
         model.setMultiModalOption(multiModalOption)
-        if (contextCache) model.setContextCache(contextCache)
 
         return model
     }
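The practical effect of these hunks is that the node now builds the official GoogleGenerativeAIChatInput shape and passes it straight to the Flowise wrapper, selecting the Gemini model via the `model` field rather than the old `modelName`, and no longer wiring up the removed context cache input. A minimal TypeScript sketch of that construction path; the node-specific plumbing (nodeData, credentials, multimodal options) is replaced here with hypothetical placeholder values:

import { GoogleGenerativeAIChatInput } from '@langchain/google-genai'
import { ChatGoogleGenerativeAI } from './FlowiseChatGoogleGenerativeAI'

// Hypothetical values standing in for what the node reads from nodeData.inputs and credentials.
const apiKey = process.env.GOOGLE_API_KEY ?? ''
const modelName = 'gemini-1.5-flash-latest'
const customModelName = ''

// Note the full GoogleGenerativeAIChatInput type (not Partial<...>) and `model` instead of `modelName`.
const obj: GoogleGenerativeAIChatInput = {
    apiKey: apiKey,
    model: customModelName || modelName,
    streaming: true
}

// The wrapper keeps the Flowise-specific node id and vision handling on top of the langchain class.
const model = new ChatGoogleGenerativeAI('node-id', obj)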
@@ -1,374 +1,5 @@
-import { BaseMessage, AIMessage, AIMessageChunk, isBaseMessage, ChatMessage, MessageContentComplex } from '@langchain/core/messages'
-import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
-[… roughly 370 further removed lines: the remaining @langchain/core, @google/generative-ai, zod-to-json-schema and FlowiseGoogleAICacheManager imports, the DEFAULT_IMAGE_MAX_TOKEN and DEFAULT_IMAGE_MODEL constants, the TokenUsage, GoogleGenerativeAIChatCallOptions and GoogleGenerativeAIChatInput interfaces, and the in-house LangchainChatGoogleGenerativeAI class (constructor, getClient, _combineLLMOutput, _llmType, bindTools, invocationParams, convertFunctionResponse, setContextCache, getNumTokens, _generateNonStreaming, _generate, _streamResponseChunks). The same code reappears, commented out, in the new file added further below. …]
+import { ChatGoogleGenerativeAI as LangchainChatGoogleGenerativeAI, GoogleGenerativeAIChatInput } from '@langchain/google-genai'
+import { IMultiModalOption, IVisionChatModal } from '../../../src'
 
 export class ChatGoogleGenerativeAI extends LangchainChatGoogleGenerativeAI implements IVisionChatModal {
     configuredModel: string
@@ -376,15 +7,15 @@ export class ChatGoogleGenerativeAI extends LangchainChatGoogleGenerativeAI implements IVisionChatModal {
     multiModalOption: IMultiModalOption
     id: string
 
-    constructor(id: string, fields?: GoogleGenerativeAIChatInput) {
+    constructor(id: string, fields: GoogleGenerativeAIChatInput) {
         super(fields)
         this.id = id
-        this.configuredModel = fields?.modelName ?? ''
+        this.configuredModel = fields?.model ?? ''
         this.configuredMaxToken = fields?.maxOutputTokens
     }
 
     revertToOriginalModel(): void {
-        this.modelName = this.configuredModel
+        this.model = this.configuredModel
         this.maxOutputTokens = this.configuredMaxToken
     }
 
@@ -393,346 +24,6 @@ export class ChatGoogleGenerativeAI extends LangchainChatGoogleGenerativeAI implements IVisionChatModal {
     }
 
     setVisionModel(): void {
-        if (this.modelName === 'gemini-1.0-pro-latest') {
-            this.modelName = DEFAULT_IMAGE_MODEL
-            this.maxOutputTokens = this.configuredMaxToken ? this.configuredMaxToken : DEFAULT_IMAGE_MAX_TOKEN
-        }
+        // pass
     }
 }
-
-[… roughly 335 further removed lines: the module-level helpers messageContentMedia, getMessageAuthor, convertAuthorToRole, convertMessageContentToParts, checkIfEmptyContentAndSameRole, convertBaseMessagesToContent, mapGenerateContentResultToChatResult, convertResponseContentToChatGenerationChunk, zodToGeminiParameters and convertToGeminiTools. These also reappear, commented out, in the new file added below. …]
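The point of the hunks above is that ChatGoogleGenerativeAI now extends the class exported by @langchain/google-genai, so structured output (withStructuredOutput) comes from the upstream implementation instead of the disabled in-house one. A minimal sketch of using the fixed wrapper, assuming a hypothetical zod schema and prompt, with error handling omitted:

import { z } from 'zod'
import { ChatGoogleGenerativeAI } from './FlowiseChatGoogleGenerativeAI'

// Hypothetical schema used only to illustrate structured output.
const AnswerSchema = z.object({
    answer: z.string().describe('The final answer'),
    confidence: z.number().min(0).max(1)
})

const model = new ChatGoogleGenerativeAI('node-id', {
    apiKey: process.env.GOOGLE_API_KEY,
    model: 'gemini-1.5-flash-latest'
})

async function main() {
    // withStructuredOutput is inherited from the upstream langchain class;
    // it returns a runnable whose output is parsed against the schema.
    const structured = model.withStructuredOutput(AnswerSchema)
    const result = await structured.invoke('How many planets are in the solar system?')
    // result is a plain object shaped like { answer: string, confidence: number }
    console.log(result)
}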
@ -0,0 +1,733 @@
|
||||||
|
/** Disabled due to the withStructuredOutput
|
||||||
|
|
||||||
|
import { BaseMessage, AIMessage, AIMessageChunk, isBaseMessage, ChatMessage, MessageContentComplex } from '@langchain/core/messages'
|
||||||
|
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
|
||||||
|
import { BaseChatModel, type BaseChatModelParams } from '@langchain/core/language_models/chat_models'
|
||||||
|
import { ChatGeneration, ChatGenerationChunk, ChatResult } from '@langchain/core/outputs'
|
||||||
|
import { ToolCallChunk } from '@langchain/core/messages/tool'
|
||||||
|
import { NewTokenIndices } from '@langchain/core/callbacks/base'
|
||||||
|
import {
|
||||||
|
EnhancedGenerateContentResponse,
|
||||||
|
Content,
|
||||||
|
Part,
|
||||||
|
Tool,
|
||||||
|
GenerativeModel,
|
||||||
|
GoogleGenerativeAI as GenerativeAI
|
||||||
|
} from '@google/generative-ai'
|
||||||
|
import type {
|
||||||
|
FunctionCallPart,
|
||||||
|
FunctionResponsePart,
|
||||||
|
SafetySetting,
|
||||||
|
UsageMetadata,
|
||||||
|
FunctionDeclarationsTool as GoogleGenerativeAIFunctionDeclarationsTool,
|
||||||
|
GenerateContentRequest
|
||||||
|
} from '@google/generative-ai'
|
||||||
|
import { ICommonObject, IMultiModalOption, IVisionChatModal } from '../../../src'
|
||||||
|
import { StructuredToolInterface } from '@langchain/core/tools'
|
||||||
|
import { isStructuredTool } from '@langchain/core/utils/function_calling'
|
||||||
|
import { zodToJsonSchema } from 'zod-to-json-schema'
|
||||||
|
import { BaseLanguageModelCallOptions } from '@langchain/core/language_models/base'
|
||||||
|
import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager'
|
||||||
|
|
||||||
|
const DEFAULT_IMAGE_MAX_TOKEN = 8192
|
||||||
|
const DEFAULT_IMAGE_MODEL = 'gemini-1.5-flash-latest'
|
||||||
|
|
||||||
|
interface TokenUsage {
|
||||||
|
completionTokens?: number
|
||||||
|
promptTokens?: number
|
||||||
|
totalTokens?: number
|
||||||
|
}
|
||||||
|
|
||||||
|
interface GoogleGenerativeAIChatCallOptions extends BaseLanguageModelCallOptions {
|
||||||
|
tools?: StructuredToolInterface[] | GoogleGenerativeAIFunctionDeclarationsTool[]
|
||||||
|
streamUsage?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GoogleGenerativeAIChatInput extends BaseChatModelParams, Pick<GoogleGenerativeAIChatCallOptions, 'streamUsage'> {
|
||||||
|
modelName?: string
|
||||||
|
model?: string
|
||||||
|
temperature?: number
|
||||||
|
maxOutputTokens?: number
|
||||||
|
topP?: number
|
||||||
|
topK?: number
|
||||||
|
stopSequences?: string[]
|
||||||
|
safetySettings?: SafetySetting[]
|
||||||
|
apiKey?: string
|
||||||
|
apiVersion?: string
|
||||||
|
baseUrl?: string
|
||||||
|
streaming?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
class LangchainChatGoogleGenerativeAI
|
||||||
|
extends BaseChatModel<GoogleGenerativeAIChatCallOptions, AIMessageChunk>
|
||||||
|
implements GoogleGenerativeAIChatInput
|
||||||
|
{
|
||||||
|
modelName = 'gemini-pro'
|
||||||
|
|
||||||
|
temperature?: number
|
||||||
|
|
||||||
|
maxOutputTokens?: number
|
||||||
|
|
||||||
|
topP?: number
|
||||||
|
|
||||||
|
topK?: number
|
||||||
|
|
||||||
|
stopSequences: string[] = []
|
||||||
|
|
||||||
|
safetySettings?: SafetySetting[]
|
||||||
|
|
||||||
|
apiKey?: string
|
||||||
|
|
||||||
|
baseUrl?: string
|
||||||
|
|
||||||
|
streaming = false
|
||||||
|
|
||||||
|
streamUsage = true
|
||||||
|
|
||||||
|
private client: GenerativeModel
|
||||||
|
|
||||||
|
private contextCache?: FlowiseGoogleAICacheManager
|
||||||
|
|
||||||
|
get _isMultimodalModel() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
constructor(fields?: GoogleGenerativeAIChatInput) {
|
||||||
|
super(fields ?? {})
|
||||||
|
|
||||||
|
this.modelName = fields?.model?.replace(/^models\//, '') ?? fields?.modelName?.replace(/^models\//, '') ?? 'gemini-pro'
|
||||||
|
|
||||||
|
this.maxOutputTokens = fields?.maxOutputTokens ?? this.maxOutputTokens
|
||||||
|
|
||||||
|
if (this.maxOutputTokens && this.maxOutputTokens < 0) {
|
||||||
|
throw new Error('`maxOutputTokens` must be a positive integer')
|
||||||
|
}
|
||||||
|
|
||||||
|
this.temperature = fields?.temperature ?? this.temperature
|
||||||
|
if (this.temperature && (this.temperature < 0 || this.temperature > 1)) {
|
||||||
|
throw new Error('`temperature` must be in the range of [0.0,1.0]')
|
||||||
|
}
|
||||||
|
|
||||||
|
this.topP = fields?.topP ?? this.topP
|
||||||
|
if (this.topP && this.topP < 0) {
|
||||||
|
throw new Error('`topP` must be a positive integer')
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.topP && this.topP > 1) {
|
||||||
|
throw new Error('`topP` must be below 1.')
|
||||||
|
}
|
||||||
|
|
||||||
|
this.topK = fields?.topK ?? this.topK
|
||||||
|
if (this.topK && this.topK < 0) {
|
||||||
|
throw new Error('`topK` must be a positive integer')
|
||||||
|
}
|
||||||
|
|
||||||
|
this.stopSequences = fields?.stopSequences ?? this.stopSequences
|
||||||
|
|
||||||
|
this.apiKey = fields?.apiKey ?? process.env['GOOGLE_API_KEY']
|
||||||
|
if (!this.apiKey) {
|
||||||
|
throw new Error(
|
||||||
|
'Please set an API key for Google GenerativeAI ' +
|
||||||
|
'in the environment variable GOOGLE_API_KEY ' +
|
||||||
|
'or in the `apiKey` field of the ' +
|
||||||
|
'ChatGoogleGenerativeAI constructor'
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
this.safetySettings = fields?.safetySettings ?? this.safetySettings
|
||||||
|
if (this.safetySettings && this.safetySettings.length > 0) {
|
||||||
|
const safetySettingsSet = new Set(this.safetySettings.map((s) => s.category))
|
||||||
|
if (safetySettingsSet.size !== this.safetySettings.length) {
|
||||||
|
throw new Error('The categories in `safetySettings` array must be unique')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this.streaming = fields?.streaming ?? this.streaming
|
||||||
|
|
||||||
|
this.streamUsage = fields?.streamUsage ?? this.streamUsage
|
||||||
|
|
||||||
|
this.getClient()
|
||||||
|
}
|
||||||
|
|
||||||
|
async getClient(prompt?: Content[], tools?: Tool[]) {
|
||||||
|
this.client = new GenerativeAI(this.apiKey ?? '').getGenerativeModel(
|
||||||
|
{
|
||||||
|
model: this.modelName,
|
||||||
|
tools,
|
||||||
|
safetySettings: this.safetySettings as SafetySetting[],
|
||||||
|
generationConfig: {
|
||||||
|
candidateCount: 1,
|
||||||
|
stopSequences: this.stopSequences,
|
||||||
|
maxOutputTokens: this.maxOutputTokens,
|
||||||
|
temperature: this.temperature,
|
||||||
|
topP: this.topP,
|
||||||
|
topK: this.topK
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
baseUrl: this.baseUrl
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if (this.contextCache) {
|
||||||
|
const cachedContent = await this.contextCache.lookup({
|
||||||
|
contents: prompt ? [{ ...prompt[0], parts: prompt[0].parts.slice(0, 1) }] : [],
|
||||||
|
model: this.modelName,
|
||||||
|
tools
|
||||||
|
})
|
||||||
|
this.client.cachedContent = cachedContent as any
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_combineLLMOutput() {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
_llmType() {
|
||||||
|
return 'googlegenerativeai'
|
||||||
|
}
|
||||||
|
|
||||||
|
override bindTools(tools: (StructuredToolInterface | Record<string, unknown>)[], kwargs?: Partial<ICommonObject>) {
|
||||||
|
//@ts-ignore
|
||||||
|
return this.bind({ tools: convertToGeminiTools(tools), ...kwargs })
|
||||||
|
}
|
||||||
|
|
||||||
|
invocationParams(options?: this['ParsedCallOptions']): Omit<GenerateContentRequest, 'contents'> {
|
||||||
|
const tools = options?.tools as GoogleGenerativeAIFunctionDeclarationsTool[] | StructuredToolInterface[] | undefined
|
||||||
|
if (Array.isArray(tools) && !tools.some((t: any) => !('lc_namespace' in t))) {
|
||||||
|
return {
|
||||||
|
tools: convertToGeminiTools(options?.tools as StructuredToolInterface[]) as any
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
tools: options?.tools as GoogleGenerativeAIFunctionDeclarationsTool[] | undefined
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
convertFunctionResponse(prompts: Content[]) {
|
||||||
|
for (let i = 0; i < prompts.length; i += 1) {
|
||||||
|
if (prompts[i].role === 'function') {
|
||||||
|
if (prompts[i - 1].role === 'model') {
|
||||||
|
const toolName = prompts[i - 1].parts[0].functionCall?.name ?? ''
|
||||||
|
prompts[i].parts = [
|
||||||
|
{
|
||||||
|
functionResponse: {
|
||||||
|
name: toolName,
|
||||||
|
response: {
|
||||||
|
name: toolName,
|
||||||
|
content: prompts[i].parts[0].text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setContextCache(contextCache: FlowiseGoogleAICacheManager): void {
|
||||||
|
this.contextCache = contextCache
|
||||||
|
}
|
||||||
|
|
||||||
|
async getNumTokens(prompt: BaseMessage[]) {
|
||||||
|
const contents = convertBaseMessagesToContent(prompt, this._isMultimodalModel)
|
||||||
|
const { totalTokens } = await this.client.countTokens({ contents })
|
||||||
|
return totalTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
async _generateNonStreaming(
|
||||||
|
prompt: Content[],
|
||||||
|
options: this['ParsedCallOptions'],
|
||||||
|
_runManager?: CallbackManagerForLLMRun
|
||||||
|
): Promise<ChatResult> {
|
||||||
|
//@ts-ignore
|
||||||
|
const tools = options.tools ?? []
|
||||||
|
|
||||||
|
this.convertFunctionResponse(prompt)
|
||||||
|
|
||||||
|
if (tools.length > 0) {
|
||||||
|
await this.getClient(prompt, tools as Tool[])
|
||||||
|
} else {
|
||||||
|
await this.getClient(prompt)
|
||||||
|
}
|
||||||
|
const res = await this.caller.callWithOptions({ signal: options?.signal }, async () => {
|
||||||
|
let output
|
||||||
|
try {
|
||||||
|
output = await this.client.generateContent({
|
||||||
|
contents: prompt
|
||||||
|
})
|
||||||
|
} catch (e: any) {
|
||||||
|
if (e.message?.includes('400 Bad Request')) {
|
||||||
|
e.status = 400
|
||||||
|
}
|
||||||
|
throw e
|
||||||
|
}
|
||||||
|
return output
|
||||||
|
})
|
||||||
|
const generationResult = mapGenerateContentResultToChatResult(res.response)
|
||||||
|
await _runManager?.handleLLMNewToken(generationResult.generations?.length ? generationResult.generations[0].text : '')
|
||||||
|
return generationResult
|
||||||
|
}
|
||||||
|
|
||||||
|
async _generate(
|
||||||
|
messages: BaseMessage[],
|
||||||
|
options: this['ParsedCallOptions'],
|
||||||
|
runManager?: CallbackManagerForLLMRun
|
||||||
|
): Promise<ChatResult> {
|
||||||
|
let prompt = convertBaseMessagesToContent(messages, this._isMultimodalModel)
|
||||||
|
prompt = checkIfEmptyContentAndSameRole(prompt)
|
||||||
|
|
||||||
|
// Handle streaming
|
||||||
|
if (this.streaming) {
|
||||||
|
const tokenUsage: TokenUsage = {}
|
||||||
|
const stream = this._streamResponseChunks(messages, options, runManager)
|
||||||
|
const finalChunks: Record<number, ChatGenerationChunk> = {}
|
||||||
|
|
||||||
|
for await (const chunk of stream) {
|
||||||
|
const index = (chunk.generationInfo as NewTokenIndices)?.completion ?? 0
|
||||||
|
if (finalChunks[index] === undefined) {
|
||||||
|
finalChunks[index] = chunk
|
||||||
|
} else {
|
||||||
|
finalChunks[index] = finalChunks[index].concat(chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const generations = Object.entries(finalChunks)
|
||||||
|
.sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10))
|
||||||
|
.map(([_, value]) => value)
|
||||||
|
|
||||||
|
return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } }
|
||||||
|
}
|
||||||
|
return this._generateNonStreaming(prompt, options, runManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
async *_streamResponseChunks(
|
||||||
|
messages: BaseMessage[],
|
||||||
|
options: this['ParsedCallOptions'],
|
||||||
|
runManager?: CallbackManagerForLLMRun
|
||||||
|
): AsyncGenerator<ChatGenerationChunk> {
|
||||||
|
let prompt = convertBaseMessagesToContent(messages, this._isMultimodalModel)
|
||||||
|
prompt = checkIfEmptyContentAndSameRole(prompt)
|
||||||
|
|
||||||
|
const parameters = this.invocationParams(options)
|
||||||
|
const request = {
|
||||||
|
...parameters,
|
||||||
|
contents: prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
const tools = options.tools ?? []
|
||||||
|
if (tools.length > 0) {
|
||||||
|
await this.getClient(prompt, tools as Tool[])
|
||||||
|
} else {
|
||||||
|
await this.getClient(prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => {
|
||||||
|
const { stream } = await this.client.generateContentStream(request)
|
||||||
|
return stream
|
||||||
|
})
|
||||||
|
|
||||||
|
let usageMetadata: UsageMetadata | ICommonObject | undefined
|
||||||
|
let index = 0
|
||||||
|
for await (const response of stream) {
|
||||||
|
if ('usageMetadata' in response && this.streamUsage !== false && options.streamUsage !== false) {
|
||||||
|
const genAIUsageMetadata = response.usageMetadata as {
|
||||||
|
promptTokenCount: number
|
||||||
|
candidatesTokenCount: number
|
||||||
|
totalTokenCount: number
|
||||||
|
}
|
||||||
|
if (!usageMetadata) {
|
||||||
|
usageMetadata = {
|
||||||
|
input_tokens: genAIUsageMetadata.promptTokenCount,
|
||||||
|
output_tokens: genAIUsageMetadata.candidatesTokenCount,
|
||||||
|
total_tokens: genAIUsageMetadata.totalTokenCount
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Under the hood, LangChain combines the prompt tokens. Google returns the updated
|
||||||
|
// total each time, so we need to find the difference between the tokens.
|
||||||
|
const outputTokenDiff = genAIUsageMetadata.candidatesTokenCount - (usageMetadata as ICommonObject).output_tokens
|
||||||
|
usageMetadata = {
|
||||||
|
input_tokens: 0,
|
||||||
|
output_tokens: outputTokenDiff,
|
||||||
|
total_tokens: outputTokenDiff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const chunk = convertResponseContentToChatGenerationChunk(response, {
|
||||||
|
usageMetadata: usageMetadata as UsageMetadata,
|
||||||
|
index
|
||||||
|
})
|
||||||
|
index += 1
|
||||||
|
if (!chunk) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
await runManager?.handleLLMNewToken(chunk.text ?? '')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export class ChatGoogleGenerativeAI extends LangchainChatGoogleGenerativeAI implements IVisionChatModal {
|
||||||
|
configuredModel: string
|
||||||
|
configuredMaxToken?: number
|
||||||
|
multiModalOption: IMultiModalOption
|
||||||
|
id: string
|
||||||
|
|
||||||
|
constructor(id: string, fields?: GoogleGenerativeAIChatInput) {
|
||||||
|
super(fields)
|
||||||
|
this.id = id
|
||||||
|
this.configuredModel = fields?.modelName ?? ''
|
||||||
|
this.configuredMaxToken = fields?.maxOutputTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
revertToOriginalModel(): void {
|
||||||
|
this.modelName = this.configuredModel
|
||||||
|
this.maxOutputTokens = this.configuredMaxToken
|
||||||
|
}
|
||||||
|
|
||||||
|
setMultiModalOption(multiModalOption: IMultiModalOption): void {
|
||||||
|
this.multiModalOption = multiModalOption
|
||||||
|
}
|
||||||
|
|
||||||
|
setVisionModel(): void {
|
||||||
|
if (this.modelName === 'gemini-1.0-pro-latest') {
|
||||||
|
this.modelName = DEFAULT_IMAGE_MODEL
|
||||||
|
this.maxOutputTokens = this.configuredMaxToken ? this.configuredMaxToken : DEFAULT_IMAGE_MAX_TOKEN
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function messageContentMedia(content: MessageContentComplex): Part {
|
||||||
|
if ('mimeType' in content && 'data' in content) {
|
||||||
|
return {
|
||||||
|
inlineData: {
|
||||||
|
mimeType: content.mimeType,
|
||||||
|
data: content.data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Error('Invalid media content')
|
||||||
|
}
|
||||||
|
|
||||||
|
function getMessageAuthor(message: BaseMessage) {
|
||||||
|
const type = message._getType()
|
||||||
|
if (ChatMessage.isInstance(message)) {
|
||||||
|
return message.role
|
||||||
|
}
|
||||||
|
return message.name ?? type
|
||||||
|
}
|
||||||
|
|
||||||
|
function convertAuthorToRole(author: string) {
|
||||||
|
switch (author.toLowerCase()) {
|
||||||
|
case 'ai':
|
||||||
|
case 'assistant':
|
||||||
|
case 'model':
|
||||||
|
return 'model'
|
||||||
|
case 'function':
|
||||||
|
case 'tool':
|
||||||
|
return 'function'
|
||||||
|
case 'system':
|
||||||
|
case 'human':
|
||||||
|
default:
|
||||||
|
return 'user'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function convertMessageContentToParts(message: BaseMessage, isMultimodalModel: boolean): Part[] {
|
||||||
|
if (typeof message.content === 'string' && message.content !== '') {
|
||||||
|
return [{ text: message.content }]
|
||||||
|
}
|
||||||
|
|
||||||
|
let functionCalls: FunctionCallPart[] = []
|
||||||
|
let functionResponses: FunctionResponsePart[] = []
|
||||||
|
let messageParts: Part[] = []
|
||||||
|
|
||||||
|
if ('tool_calls' in message && Array.isArray(message.tool_calls) && message.tool_calls.length > 0) {
|
||||||
|
functionCalls = message.tool_calls.map((tc) => ({
|
||||||
|
functionCall: {
|
||||||
|
name: tc.name,
|
||||||
|
args: tc.args
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
} else if (message._getType() === 'tool' && message.name && message.content) {
|
||||||
|
functionResponses = [
|
||||||
|
{
|
||||||
|
functionResponse: {
|
||||||
|
name: message.name,
|
||||||
|
response: message.content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
} else if (Array.isArray(message.content)) {
|
||||||
|
messageParts = message.content.map((c) => {
|
||||||
|
if (c.type === 'text') {
|
||||||
|
return {
|
||||||
|
text: c.text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (c.type === 'image_url') {
|
||||||
|
if (!isMultimodalModel) {
|
||||||
|
throw new Error(`This model does not support images`)
|
||||||
|
}
|
||||||
|
let source
|
||||||
|
if (typeof c.image_url === 'string') {
|
||||||
|
source = c.image_url
|
||||||
|
} else if (typeof c.image_url === 'object' && 'url' in c.image_url) {
|
||||||
|
source = c.image_url.url
|
||||||
|
} else {
|
||||||
|
throw new Error('Please provide image as base64 encoded data URL')
|
||||||
|
}
|
||||||
|
const [dm, data] = source.split(',')
|
||||||
|
if (!dm.startsWith('data:')) {
|
||||||
|
throw new Error('Please provide image as base64 encoded data URL')
|
||||||
|
}
|
||||||
|
|
||||||
|
const [mimeType, encoding] = dm.replace(/^data:/, '').split(';')
|
||||||
|
if (encoding !== 'base64') {
|
||||||
|
throw new Error('Please provide image as base64 encoded data URL')
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
inlineData: {
|
||||||
|
data,
|
||||||
|
mimeType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (c.type === 'media') {
|
||||||
|
return messageContentMedia(c)
|
||||||
|
} else if (c.type === 'tool_use') {
|
||||||
|
return {
|
||||||
|
functionCall: {
|
||||||
|
name: c.name,
|
||||||
|
args: c.input
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw new Error(`Unknown content type ${(c as { type: string }).type}`)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return [...messageParts, ...functionCalls, ...functionResponses]
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is a dedicated logic for Multi Agent Supervisor to handle the case where the content is empty, and the role is the same
|
||||||
|
function checkIfEmptyContentAndSameRole(contents: Content[]) {
|
||||||
|
let prevRole = ''
|
||||||
|
const validContents: Content[] = []
|
||||||
|
|
||||||
|
for (const content of contents) {
|
||||||
|
// Skip only if completely empty
|
||||||
|
if (!content.parts || !content.parts.length) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure role is always either 'user' or 'model'
|
||||||
|
content.role = content.role === 'model' ? 'model' : 'user'
|
||||||
|
|
||||||
|
// Handle consecutive messages
|
||||||
|
if (content.role === prevRole && validContents.length > 0) {
|
||||||
|
// Merge with previous content if same role
|
||||||
|
validContents[validContents.length - 1].parts.push(...content.parts)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
validContents.push(content)
|
||||||
|
prevRole = content.role
|
||||||
|
}
|
||||||
|
|
||||||
|
return validContents
|
||||||
|
}
|
||||||
|
|
||||||
|
function convertBaseMessagesToContent(messages: BaseMessage[], isMultimodalModel: boolean) {
|
||||||
|
return messages.reduce<{
|
||||||
|
content: Content[]
|
||||||
|
mergeWithPreviousContent: boolean
|
||||||
|
}>(
|
||||||
|
(acc, message, index) => {
|
||||||
|
if (!isBaseMessage(message)) {
|
||||||
|
throw new Error('Unsupported message input')
|
||||||
|
}
|
||||||
|
const author = getMessageAuthor(message)
|
||||||
|
if (author === 'system' && index !== 0) {
|
||||||
|
throw new Error('System message should be the first one')
|
||||||
|
}
|
||||||
|
const role = convertAuthorToRole(author)
|
||||||
|
|
||||||
|
const prevContent = acc.content[acc.content.length]
|
||||||
|
if (!acc.mergeWithPreviousContent && prevContent && prevContent.role === role) {
|
||||||
|
throw new Error('Google Generative AI requires alternate messages between authors')
|
||||||
|
}
|
||||||
|
|
||||||
|
const parts = convertMessageContentToParts(message, isMultimodalModel)
|
||||||
|
|
||||||
|
if (acc.mergeWithPreviousContent) {
|
||||||
|
const prevContent = acc.content[acc.content.length - 1]
|
||||||
|
if (!prevContent) {
|
||||||
|
throw new Error('There was a problem parsing your system message. Please try a prompt without one.')
|
||||||
|
}
|
||||||
|
prevContent.parts.push(...parts)
|
||||||
|
|
||||||
|
return {
|
||||||
|
mergeWithPreviousContent: false,
|
||||||
|
content: acc.content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let actualRole = role
|
||||||
|
if (actualRole === 'function' || actualRole === 'tool') {
|
||||||
|
// GenerativeAI API will throw an error if the role is not "user" or "model."
|
||||||
|
actualRole = 'user'
|
||||||
|
}
|
||||||
|
const content: Content = {
|
||||||
|
role: actualRole,
|
||||||
|
parts
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
mergeWithPreviousContent: author === 'system',
|
||||||
|
content: [...acc.content, content]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ content: [], mergeWithPreviousContent: false }
|
||||||
|
).content
|
||||||
|
}
|
||||||
|
|
||||||
|
function mapGenerateContentResultToChatResult(
|
||||||
|
response: EnhancedGenerateContentResponse,
|
||||||
|
extra?: {
|
||||||
|
usageMetadata: UsageMetadata | undefined
|
||||||
|
}
|
||||||
|
): ChatResult {
|
||||||
|
// if rejected or error, return empty generations with reason in filters
|
||||||
|
if (!response.candidates || response.candidates.length === 0 || !response.candidates[0]) {
|
||||||
|
return {
|
||||||
|
generations: [],
|
||||||
|
llmOutput: {
|
||||||
|
filters: response.promptFeedback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const functionCalls = response.functionCalls()
|
||||||
|
const [candidate] = response.candidates
|
||||||
|
const { content, ...generationInfo } = candidate
|
||||||
|
const text = content?.parts[0]?.text ?? ''
|
||||||
|
|
||||||
|
const generation: ChatGeneration = {
|
||||||
|
text,
|
||||||
|
message: new AIMessage({
|
||||||
|
content: text,
|
||||||
|
tool_calls: functionCalls,
|
||||||
|
additional_kwargs: {
|
||||||
|
...generationInfo
|
||||||
|
},
|
||||||
|
usage_metadata: extra?.usageMetadata as any
|
||||||
|
}),
|
||||||
|
generationInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
generations: [generation]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function convertResponseContentToChatGenerationChunk(
    response: EnhancedGenerateContentResponse,
    extra: {
        usageMetadata?: UsageMetadata | undefined
        index: number
    }
): ChatGenerationChunk | null {
    if (!response || !response.candidates || response.candidates.length === 0) {
        return null
    }
    const functionCalls = response.functionCalls()
    const [candidate] = response.candidates
    const { content, ...generationInfo } = candidate
    const text = content?.parts?.[0]?.text ?? ''

    const toolCallChunks: ToolCallChunk[] = []
    if (functionCalls) {
        toolCallChunks.push(
            ...functionCalls.map((fc) => ({
                ...fc,
                args: JSON.stringify(fc.args),
                index: extra.index
            }))
        )
    }
    return new ChatGenerationChunk({
        text,
        message: new AIMessageChunk({
            content: text,
            name: !content ? undefined : content.role,
            tool_call_chunks: toolCallChunks,
            // Each chunk can have unique "generationInfo", and merging strategy is unclear,
            // so leave blank for now.
            additional_kwargs: {},
            usage_metadata: extra.usageMetadata as any
        }),
        generationInfo
    })
}

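As an aside on how chunks from convertResponseContentToChatGenerationChunk are typically consumed, here is an illustrative sketch only; the aggregateStream helper and its stream argument are assumptions for illustration, not code from this commit. It simply merges successive ChatGenerationChunk values with concat().

import { ChatGenerationChunk } from '@langchain/core/outputs'
import type { EnhancedGenerateContentResponse } from '@google/generative-ai'

// Hypothetical helper: merge streamed Gemini responses into one aggregated generation chunk.
async function aggregateStream(stream: AsyncIterable<EnhancedGenerateContentResponse>) {
    let aggregate: ChatGenerationChunk | null = null
    let index = 0
    for await (const response of stream) {
        const chunk = convertResponseContentToChatGenerationChunk(response, { index })
        index += 1
        if (!chunk) continue
        // ChatGenerationChunk.concat merges the text, the message chunk and its tool_call_chunks
        aggregate = aggregate ? aggregate.concat(chunk) : chunk
    }
    return aggregate
}
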
function zodToGeminiParameters(zodObj: any) {
    // Gemini doesn't accept either the $schema or additionalProperties
    // attributes, so we need to explicitly remove them.
    const jsonSchema: any = zodToJsonSchema(zodObj)
    // eslint-disable-next-line unused-imports/no-unused-vars
    const { $schema, additionalProperties, ...rest } = jsonSchema

    // Ensure all properties have type specified
    if (rest.properties) {
        Object.keys(rest.properties).forEach((key) => {
            const prop = rest.properties[key]

            // Handle enum types
            if (prop.enum?.length) {
                rest.properties[key] = {
                    type: 'string',
                    format: 'enum',
                    enum: prop.enum
                }
            }
            // Handle missing type
            else if (!prop.type && !prop.oneOf && !prop.anyOf && !prop.allOf) {
                // Infer type from other properties
                if (prop.minimum !== undefined || prop.maximum !== undefined) {
                    prop.type = 'number'
                } else if (prop.format === 'date-time') {
                    prop.type = 'string'
                } else if (prop.items) {
                    prop.type = 'array'
                } else if (prop.properties) {
                    prop.type = 'object'
                } else {
                    // Default to string if type can't be inferred
                    prop.type = 'string'
                }
            }
        })
    }

    return rest
}

function convertToGeminiTools(structuredTools: (StructuredToolInterface | Record<string, unknown>)[]) {
    return [
        {
            functionDeclarations: structuredTools.map((structuredTool) => {
                if (isStructuredTool(structuredTool)) {
                    const jsonSchema = zodToGeminiParameters(structuredTool.schema)
                    return {
                        name: structuredTool.name,
                        description: structuredTool.description,
                        parameters: jsonSchema
                    }
                }
                return structuredTool
            })
        }
    ]
}
*/

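For orientation, a hedged sketch of what zodToGeminiParameters above produces for a small schema; the weatherSchema example and the expected output shape are assumptions for illustration, not output captured from this commit.

import { z } from 'zod'

// Hypothetical input schema, for illustration only
const weatherSchema = z.object({
    city: z.string().describe('City to look up'),
    unit: z.enum(['celsius', 'fahrenheit'])
})

// zodToGeminiParameters(weatherSchema) should yield roughly:
// {
//     type: 'object',
//     properties: {
//         city: { type: 'string', description: 'City to look up' },
//         unit: { type: 'string', format: 'enum', enum: ['celsius', 'fahrenheit'] }
//     },
//     required: ['city', 'unit']
// }
// i.e. the zod-to-json-schema output with $schema and additionalProperties stripped,
// and enum properties rewritten into Gemini's string/enum format.

convertToGeminiTools would then wrap such a cleaned schema as a functionDeclaration entry next to the tool's name and description.
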
@@ -19,8 +19,8 @@ import { AgentExecutor, JsonOutputToolsParser, ToolCallingAgentOutputParser } fr
 import { ChatMistralAI } from '@langchain/mistralai'
 import { ChatOpenAI } from '../../chatmodels/ChatOpenAI/FlowiseChatOpenAI'
 import { ChatAnthropic } from '../../chatmodels/ChatAnthropic/FlowiseChatAnthropic'
-import { ChatGoogleGenerativeAI } from '../../chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI'
 import { addImagesToMessages, llmSupportsVision } from '../../../src/multiModalUtils'
+import { ChatGoogleGenerativeAI } from '../../chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI'

 const sysPrompt = `You are a supervisor tasked with managing a conversation between the following workers: {team_members}.
 Given the following user request, respond with the worker to act next.

@@ -3,8 +3,7 @@ import { VectorStore, VectorStoreRetriever, VectorStoreRetrieverInput } from '@l
 import { INode, INodeData, INodeParams, INodeOutputsValue } from '../../../src/Interface'
 import { handleEscapeCharacters } from '../../../src'
 import { z } from 'zod'
-import { convertStructuredSchemaToZod, ExtractTool } from '../../sequentialagents/commonUtils'
-import { ChatGoogleGenerativeAI } from '@langchain/google-genai'
+import { convertStructuredSchemaToZod } from '../../sequentialagents/commonUtils'

 const queryPrefix = 'query'
 const defaultPrompt = `Extract keywords from the query: {{${queryPrefix}}}`

@@ -126,19 +125,8 @@ class ExtractMetadataRetriever_Retrievers implements INode {
         try {
             const structuredOutput = z.object(convertStructuredSchemaToZod(llmStructuredOutput))
-
-            if (llm instanceof ChatGoogleGenerativeAI) {
-                const tool = new ExtractTool({
-                    schema: structuredOutput
-                })
-                // @ts-ignore
-                const modelWithTool = llm.bind({
-                    tools: [tool]
-                }) as any
-                llm = modelWithTool
-            } else {
-                // @ts-ignore
-                llm = llm.withStructuredOutput(structuredOutput)
-            }
+            // @ts-ignore
+            llm = llm.withStructuredOutput(structuredOutput)
         } catch (exception) {
             console.error(exception)
         }

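The same simplification recurs in the sequential-agent nodes below: the Gemini-specific ExtractTool/bind branch is dropped and withStructuredOutput is used for every provider. A minimal sketch of the resulting call path; the schema, model name and sample query are assumptions for illustration, not code from this commit.

import { z } from 'zod'
import { ChatGoogleGenerativeAI } from '@langchain/google-genai'

// Illustrative schema; in Flowise it comes from convertStructuredSchemaToZod(llmStructuredOutput)
const structuredOutput = z.object({
    keywords: z.string().describe('Comma-separated keywords extracted from the query')
})

// Placeholder model settings
const llm = new ChatGoogleGenerativeAI({
    apiKey: process.env.GOOGLE_API_KEY,
    model: 'gemini-1.5-flash',
    temperature: 0
})

// withStructuredOutput now works for Gemini, so no provider-specific branch is needed
const structuredLLM = llm.withStructuredOutput(structuredOutput)
// const result = await structuredLLM.invoke('Find papers about solar panel efficiency published in 2023')
// result.keywords -> e.g. 'solar panels, efficiency, 2023' (illustrative)
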
@@ -18,7 +18,6 @@ import {
 } from '../../../src/Interface'
 import { getInputVariables, getVars, handleEscapeCharacters, prepareSandboxVars, transformBracesWithColon } from '../../../src/utils'
 import {
-    ExtractTool,
     checkCondition,
     convertStructuredSchemaToZod,
     customGet,

@@ -27,7 +26,6 @@ import {
     filterConversationHistory,
     restructureMessages
 } from '../commonUtils'
-import { ChatGoogleGenerativeAI } from '../../chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI'

 interface IConditionGridItem {
     variable: string

@@ -485,20 +483,8 @@ const runCondition = async (
         try {
             const structuredOutput = z.object(convertStructuredSchemaToZod(conditionAgentStructuredOutput))
-
-            if (llm instanceof ChatGoogleGenerativeAI) {
-                const tool = new ExtractTool({
-                    schema: structuredOutput
-                })
-                // @ts-ignore
-                const modelWithTool = llm.bind({
-                    tools: [tool],
-                    signal: abortControllerSignal ? abortControllerSignal.signal : undefined
-                })
-                model = modelWithTool
-            } else {
-                // @ts-ignore
-                model = llm.withStructuredOutput(structuredOutput)
-            }
+            // @ts-ignore
+            model = llm.withStructuredOutput(structuredOutput)
         } catch (exception) {
             console.error('Invalid JSON in Condition Agent Structured Output: ' + exception)
             model = llm

@@ -27,7 +27,6 @@ import {
     transformBracesWithColon
 } from '../../../src/utils'
 import {
-    ExtractTool,
     convertStructuredSchemaToZod,
     customGet,
     getVM,

@@ -37,7 +36,6 @@ import {
     restructureMessages,
     checkMessageHistory
 } from '../commonUtils'
-import { ChatGoogleGenerativeAI } from '../../chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI'

 const TAB_IDENTIFIER = 'selectedUpdateStateMemoryTab'
 const customOutputFuncDesc = `This is only applicable when you have a custom State at the START node. After agent execution, you might want to update the State values`

@@ -513,19 +511,8 @@ async function createAgent(
         try {
             const structuredOutput = z.object(convertStructuredSchemaToZod(llmStructuredOutput))
-
-            if (llm instanceof ChatGoogleGenerativeAI) {
-                const tool = new ExtractTool({
-                    schema: structuredOutput
-                })
-                // @ts-ignore
-                const modelWithTool = llm.bind({
-                    tools: [tool]
-                }) as any
-                llm = modelWithTool
-            } else {
-                // @ts-ignore
-                llm = llm.withStructuredOutput(structuredOutput)
-            }
+            // @ts-ignore
+            llm = llm.withStructuredOutput(structuredOutput)
         } catch (exception) {
             console.error(exception)
         }

@@ -71,24 +71,13 @@ export const generateFollowUpPrompts = async (
             return structuredResponse
         }
         case FollowUpPromptProvider.GOOGLE_GENAI: {
-            const llm = new ChatGoogleGenerativeAI({
+            const model = new ChatGoogleGenerativeAI({
                 apiKey: credentialData.googleGenerativeAPIKey,
                 model: providerConfig.modelName,
                 temperature: parseFloat(`${providerConfig.temperature}`)
             })
-            // use structured output parser because withStructuredOutput is not working
-            const parser = StructuredOutputParser.fromZodSchema(FollowUpPromptType)
-            const formatInstructions = parser.getFormatInstructions()
-            const prompt = PromptTemplate.fromTemplate(`
-                ${providerConfig.prompt}
-
-                {format_instructions}
-            `)
-            const chain = prompt.pipe(llm).pipe(parser)
-            const structuredResponse = await chain.invoke({
-                history: apiMessageContent,
-                format_instructions: formatInstructions
-            })
+            const structuredLLM = model.withStructuredOutput(FollowUpPromptType)
+            const structuredResponse = await structuredLLM.invoke(followUpPromptsPrompt)
             return structuredResponse
         }
         case FollowUpPromptProvider.MISTRALAI: {
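For the follow-up prompt generation above, the format-instructions prompt and StructuredOutputParser chain give way to a direct withStructuredOutput call. A hedged sketch of the new flow; the simplified FollowUpPromptType stand-in, model name and prompt string are assumptions for illustration only.

import { z } from 'zod'
import { ChatGoogleGenerativeAI } from '@langchain/google-genai'

// Simplified stand-in for FollowUpPromptType (the real schema lives in Flowise)
const FollowUpPromptType = z.object({
    questions: z.array(z.string()).describe('Follow-up questions the user might ask next')
})

const model = new ChatGoogleGenerativeAI({
    apiKey: process.env.GOOGLE_API_KEY,
    model: 'gemini-1.5-flash', // placeholder; Flowise passes providerConfig.modelName
    temperature: 0.7
})

// The zod schema is handed straight to the model; no format-instruction template or output parser
const structuredLLM = model.withStructuredOutput(FollowUpPromptType)
const followUpPromptsPrompt = 'Suggest follow-up questions for a chat about solar panel efficiency.' // assumed prompt
// const structuredResponse = await structuredLLM.invoke(followUpPromptsPrompt)
// structuredResponse.questions -> string[] matching the schema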