Compare commits
2 Commits
main...chore/Goog

| Author | SHA1 | Date |
|---|---|---|
|  | dc982c2110 |  |
|  | 392c285d48 |  |
@@ -89,7 +89,7 @@
     "resolutions": {
         "@google/generative-ai": "^0.24.0",
         "@grpc/grpc-js": "^1.10.10",
-        "@langchain/core": "0.3.37",
+        "@langchain/core": "0.3.61",
         "@qdrant/openapi-typescript-fetch": "1.2.6",
         "openai": "4.96.0",
         "protobufjs": "7.4.0"
@@ -524,7 +524,7 @@ class Agent_Agentflow implements INode {
         }
         const componentNode = options.componentNodes[agentSelectedTool]

-        const jsonSchema = zodToJsonSchema(tool.schema)
+        const jsonSchema = zodToJsonSchema(tool.schema as any)
         if (jsonSchema.$schema) {
             delete jsonSchema.$schema
         }
@@ -161,7 +161,7 @@ class Tool_Agentflow implements INode {
         toolInputArgs = { properties: allProperties }
     } else {
         // Handle single tool instance
-        toolInputArgs = toolInstance.schema ? zodToJsonSchema(toolInstance.schema) : {}
+        toolInputArgs = toolInstance.schema ? zodToJsonSchema(toolInstance.schema as any) : {}
     }

     if (toolInputArgs && Object.keys(toolInputArgs).length > 0) {
@@ -4,8 +4,7 @@ import { BaseCache } from '@langchain/core/caches'
 import { ICommonObject, IMultiModalOption, INode, INodeData, INodeOptionsValue, INodeParams } from '../../../src/Interface'
 import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
 import { getModels, MODEL_TYPE } from '../../../src/modelLoader'
-import { ChatGoogleGenerativeAI } from './FlowiseChatGoogleGenerativeAI'
-import { GoogleGenerativeAIChatInput } from '@langchain/google-genai'
+import { ChatGoogleGenerativeAI, GoogleGenerativeAIChatInput } from './FlowiseChatGoogleGenerativeAI'

 class GoogleGenerativeAI_ChatModels implements INode {
     label: string
@@ -1,733 +0,0 @@
/** Disabled due to the withStructuredOutput

import { BaseMessage, AIMessage, AIMessageChunk, isBaseMessage, ChatMessage, MessageContentComplex } from '@langchain/core/messages'
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
import { BaseChatModel, type BaseChatModelParams } from '@langchain/core/language_models/chat_models'
import { ChatGeneration, ChatGenerationChunk, ChatResult } from '@langchain/core/outputs'
import { ToolCallChunk } from '@langchain/core/messages/tool'
import { NewTokenIndices } from '@langchain/core/callbacks/base'
import {
    EnhancedGenerateContentResponse,
    Content,
    Part,
    Tool,
    GenerativeModel,
    GoogleGenerativeAI as GenerativeAI
} from '@google/generative-ai'
import type {
    FunctionCallPart,
    FunctionResponsePart,
    SafetySetting,
    UsageMetadata,
    FunctionDeclarationsTool as GoogleGenerativeAIFunctionDeclarationsTool,
    GenerateContentRequest
} from '@google/generative-ai'
import { ICommonObject, IMultiModalOption, IVisionChatModal } from '../../../src'
import { StructuredToolInterface } from '@langchain/core/tools'
import { isStructuredTool } from '@langchain/core/utils/function_calling'
import { zodToJsonSchema } from 'zod-to-json-schema'
import { BaseLanguageModelCallOptions } from '@langchain/core/language_models/base'
import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager'

const DEFAULT_IMAGE_MAX_TOKEN = 8192
const DEFAULT_IMAGE_MODEL = 'gemini-1.5-flash-latest'

interface TokenUsage {
    completionTokens?: number
    promptTokens?: number
    totalTokens?: number
}

interface GoogleGenerativeAIChatCallOptions extends BaseLanguageModelCallOptions {
    tools?: StructuredToolInterface[] | GoogleGenerativeAIFunctionDeclarationsTool[]
    streamUsage?: boolean
}

export interface GoogleGenerativeAIChatInput extends BaseChatModelParams, Pick<GoogleGenerativeAIChatCallOptions, 'streamUsage'> {
    modelName?: string
    model?: string
    temperature?: number
    maxOutputTokens?: number
    topP?: number
    topK?: number
    stopSequences?: string[]
    safetySettings?: SafetySetting[]
    apiKey?: string
    apiVersion?: string
    baseUrl?: string
    streaming?: boolean
}

class LangchainChatGoogleGenerativeAI
    extends BaseChatModel<GoogleGenerativeAIChatCallOptions, AIMessageChunk>
    implements GoogleGenerativeAIChatInput
{
    modelName = 'gemini-pro'

    temperature?: number

    maxOutputTokens?: number

    topP?: number

    topK?: number

    stopSequences: string[] = []

    safetySettings?: SafetySetting[]

    apiKey?: string

    baseUrl?: string

    streaming = false

    streamUsage = true

    private client: GenerativeModel

    private contextCache?: FlowiseGoogleAICacheManager

    get _isMultimodalModel() {
        return true
    }

    constructor(fields?: GoogleGenerativeAIChatInput) {
        super(fields ?? {})

        this.modelName = fields?.model?.replace(/^models\//, '') ?? fields?.modelName?.replace(/^models\//, '') ?? 'gemini-pro'

        this.maxOutputTokens = fields?.maxOutputTokens ?? this.maxOutputTokens

        if (this.maxOutputTokens && this.maxOutputTokens < 0) {
            throw new Error('`maxOutputTokens` must be a positive integer')
        }

        this.temperature = fields?.temperature ?? this.temperature
        if (this.temperature && (this.temperature < 0 || this.temperature > 1)) {
            throw new Error('`temperature` must be in the range of [0.0,1.0]')
        }

        this.topP = fields?.topP ?? this.topP
        if (this.topP && this.topP < 0) {
            throw new Error('`topP` must be a positive integer')
        }

        if (this.topP && this.topP > 1) {
            throw new Error('`topP` must be below 1.')
        }

        this.topK = fields?.topK ?? this.topK
        if (this.topK && this.topK < 0) {
            throw new Error('`topK` must be a positive integer')
        }

        this.stopSequences = fields?.stopSequences ?? this.stopSequences

        this.apiKey = fields?.apiKey ?? process.env['GOOGLE_API_KEY']
        if (!this.apiKey) {
            throw new Error(
                'Please set an API key for Google GenerativeAI ' +
                    'in the environment variable GOOGLE_API_KEY ' +
                    'or in the `apiKey` field of the ' +
                    'ChatGoogleGenerativeAI constructor'
            )
        }

        this.safetySettings = fields?.safetySettings ?? this.safetySettings
        if (this.safetySettings && this.safetySettings.length > 0) {
            const safetySettingsSet = new Set(this.safetySettings.map((s) => s.category))
            if (safetySettingsSet.size !== this.safetySettings.length) {
                throw new Error('The categories in `safetySettings` array must be unique')
            }
        }

        this.streaming = fields?.streaming ?? this.streaming

        this.streamUsage = fields?.streamUsage ?? this.streamUsage

        this.getClient()
    }

    async getClient(prompt?: Content[], tools?: Tool[]) {
        this.client = new GenerativeAI(this.apiKey ?? '').getGenerativeModel(
            {
                model: this.modelName,
                tools,
                safetySettings: this.safetySettings as SafetySetting[],
                generationConfig: {
                    candidateCount: 1,
                    stopSequences: this.stopSequences,
                    maxOutputTokens: this.maxOutputTokens,
                    temperature: this.temperature,
                    topP: this.topP,
                    topK: this.topK
                }
            },
            {
                baseUrl: this.baseUrl
            }
        )
        if (this.contextCache) {
            const cachedContent = await this.contextCache.lookup({
                contents: prompt ? [{ ...prompt[0], parts: prompt[0].parts.slice(0, 1) }] : [],
                model: this.modelName,
                tools
            })
            this.client.cachedContent = cachedContent as any
        }
    }

    _combineLLMOutput() {
        return []
    }

    _llmType() {
        return 'googlegenerativeai'
    }

    override bindTools(tools: (StructuredToolInterface | Record<string, unknown>)[], kwargs?: Partial<ICommonObject>) {
        //@ts-ignore
        return this.bind({ tools: convertToGeminiTools(tools), ...kwargs })
    }

    invocationParams(options?: this['ParsedCallOptions']): Omit<GenerateContentRequest, 'contents'> {
        const tools = options?.tools as GoogleGenerativeAIFunctionDeclarationsTool[] | StructuredToolInterface[] | undefined
        if (Array.isArray(tools) && !tools.some((t: any) => !('lc_namespace' in t))) {
            return {
                tools: convertToGeminiTools(options?.tools as StructuredToolInterface[]) as any
            }
        }
        return {
            tools: options?.tools as GoogleGenerativeAIFunctionDeclarationsTool[] | undefined
        }
    }

    convertFunctionResponse(prompts: Content[]) {
        for (let i = 0; i < prompts.length; i += 1) {
            if (prompts[i].role === 'function') {
                if (prompts[i - 1].role === 'model') {
                    const toolName = prompts[i - 1].parts[0].functionCall?.name ?? ''
                    prompts[i].parts = [
                        {
                            functionResponse: {
                                name: toolName,
                                response: {
                                    name: toolName,
                                    content: prompts[i].parts[0].text
                                }
                            }
                        }
                    ]
                }
            }
        }
    }

    setContextCache(contextCache: FlowiseGoogleAICacheManager): void {
        this.contextCache = contextCache
    }

    async getNumTokens(prompt: BaseMessage[]) {
        const contents = convertBaseMessagesToContent(prompt, this._isMultimodalModel)
        const { totalTokens } = await this.client.countTokens({ contents })
        return totalTokens
    }

    async _generateNonStreaming(
        prompt: Content[],
        options: this['ParsedCallOptions'],
        _runManager?: CallbackManagerForLLMRun
    ): Promise<ChatResult> {
        //@ts-ignore
        const tools = options.tools ?? []

        this.convertFunctionResponse(prompt)

        if (tools.length > 0) {
            await this.getClient(prompt, tools as Tool[])
        } else {
            await this.getClient(prompt)
        }
        const res = await this.caller.callWithOptions({ signal: options?.signal }, async () => {
            let output
            try {
                output = await this.client.generateContent({
                    contents: prompt
                })
            } catch (e: any) {
                if (e.message?.includes('400 Bad Request')) {
                    e.status = 400
                }
                throw e
            }
            return output
        })
        const generationResult = mapGenerateContentResultToChatResult(res.response)
        await _runManager?.handleLLMNewToken(generationResult.generations?.length ? generationResult.generations[0].text : '')
        return generationResult
    }

    async _generate(
        messages: BaseMessage[],
        options: this['ParsedCallOptions'],
        runManager?: CallbackManagerForLLMRun
    ): Promise<ChatResult> {
        let prompt = convertBaseMessagesToContent(messages, this._isMultimodalModel)
        prompt = checkIfEmptyContentAndSameRole(prompt)

        // Handle streaming
        if (this.streaming) {
            const tokenUsage: TokenUsage = {}
            const stream = this._streamResponseChunks(messages, options, runManager)
            const finalChunks: Record<number, ChatGenerationChunk> = {}

            for await (const chunk of stream) {
                const index = (chunk.generationInfo as NewTokenIndices)?.completion ?? 0
                if (finalChunks[index] === undefined) {
                    finalChunks[index] = chunk
                } else {
                    finalChunks[index] = finalChunks[index].concat(chunk)
                }
            }
            const generations = Object.entries(finalChunks)
                .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10))
                .map(([_, value]) => value)

            return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } }
        }
        return this._generateNonStreaming(prompt, options, runManager)
    }

    async *_streamResponseChunks(
        messages: BaseMessage[],
        options: this['ParsedCallOptions'],
        runManager?: CallbackManagerForLLMRun
    ): AsyncGenerator<ChatGenerationChunk> {
        let prompt = convertBaseMessagesToContent(messages, this._isMultimodalModel)
        prompt = checkIfEmptyContentAndSameRole(prompt)

        const parameters = this.invocationParams(options)
        const request = {
            ...parameters,
            contents: prompt
        }

        const tools = options.tools ?? []
        if (tools.length > 0) {
            await this.getClient(prompt, tools as Tool[])
        } else {
            await this.getClient(prompt)
        }

        const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => {
            const { stream } = await this.client.generateContentStream(request)
            return stream
        })

        let usageMetadata: UsageMetadata | ICommonObject | undefined
        let index = 0
        for await (const response of stream) {
            if ('usageMetadata' in response && this.streamUsage !== false && options.streamUsage !== false) {
                const genAIUsageMetadata = response.usageMetadata as {
                    promptTokenCount: number
                    candidatesTokenCount: number
                    totalTokenCount: number
                }
                if (!usageMetadata) {
                    usageMetadata = {
                        input_tokens: genAIUsageMetadata.promptTokenCount,
                        output_tokens: genAIUsageMetadata.candidatesTokenCount,
                        total_tokens: genAIUsageMetadata.totalTokenCount
                    }
                } else {
                    // Under the hood, LangChain combines the prompt tokens. Google returns the updated
                    // total each time, so we need to find the difference between the tokens.
                    const outputTokenDiff = genAIUsageMetadata.candidatesTokenCount - (usageMetadata as ICommonObject).output_tokens
                    usageMetadata = {
                        input_tokens: 0,
                        output_tokens: outputTokenDiff,
                        total_tokens: outputTokenDiff
                    }
                }
            }

            const chunk = convertResponseContentToChatGenerationChunk(response, {
                usageMetadata: usageMetadata as UsageMetadata,
                index
            })
            index += 1
            if (!chunk) {
                continue
            }

            yield chunk
            await runManager?.handleLLMNewToken(chunk.text ?? '')
        }
    }
}

export class ChatGoogleGenerativeAI extends LangchainChatGoogleGenerativeAI implements IVisionChatModal {
    configuredModel: string
    configuredMaxToken?: number
    multiModalOption: IMultiModalOption
    id: string

    constructor(id: string, fields?: GoogleGenerativeAIChatInput) {
        super(fields)
        this.id = id
        this.configuredModel = fields?.modelName ?? ''
        this.configuredMaxToken = fields?.maxOutputTokens
    }

    revertToOriginalModel(): void {
        this.modelName = this.configuredModel
        this.maxOutputTokens = this.configuredMaxToken
    }

    setMultiModalOption(multiModalOption: IMultiModalOption): void {
        this.multiModalOption = multiModalOption
    }

    setVisionModel(): void {
        if (this.modelName === 'gemini-1.0-pro-latest') {
            this.modelName = DEFAULT_IMAGE_MODEL
            this.maxOutputTokens = this.configuredMaxToken ? this.configuredMaxToken : DEFAULT_IMAGE_MAX_TOKEN
        }
    }
}

function messageContentMedia(content: MessageContentComplex): Part {
    if ('mimeType' in content && 'data' in content) {
        return {
            inlineData: {
                mimeType: content.mimeType,
                data: content.data
            }
        }
    }

    throw new Error('Invalid media content')
}

function getMessageAuthor(message: BaseMessage) {
    const type = message._getType()
    if (ChatMessage.isInstance(message)) {
        return message.role
    }
    return message.name ?? type
}

function convertAuthorToRole(author: string) {
    switch (author.toLowerCase()) {
        case 'ai':
        case 'assistant':
        case 'model':
            return 'model'
        case 'function':
        case 'tool':
            return 'function'
        case 'system':
        case 'human':
        default:
            return 'user'
    }
}

function convertMessageContentToParts(message: BaseMessage, isMultimodalModel: boolean): Part[] {
    if (typeof message.content === 'string' && message.content !== '') {
        return [{ text: message.content }]
    }

    let functionCalls: FunctionCallPart[] = []
    let functionResponses: FunctionResponsePart[] = []
    let messageParts: Part[] = []

    if ('tool_calls' in message && Array.isArray(message.tool_calls) && message.tool_calls.length > 0) {
        functionCalls = message.tool_calls.map((tc) => ({
            functionCall: {
                name: tc.name,
                args: tc.args
            }
        }))
    } else if (message._getType() === 'tool' && message.name && message.content) {
        functionResponses = [
            {
                functionResponse: {
                    name: message.name,
                    response: message.content
                }
            }
        ]
    } else if (Array.isArray(message.content)) {
        messageParts = message.content.map((c) => {
            if (c.type === 'text') {
                return {
                    text: c.text
                }
            }

            if (c.type === 'image_url') {
                if (!isMultimodalModel) {
                    throw new Error(`This model does not support images`)
                }
                let source
                if (typeof c.image_url === 'string') {
                    source = c.image_url
                } else if (typeof c.image_url === 'object' && 'url' in c.image_url) {
                    source = c.image_url.url
                } else {
                    throw new Error('Please provide image as base64 encoded data URL')
                }
                const [dm, data] = source.split(',')
                if (!dm.startsWith('data:')) {
                    throw new Error('Please provide image as base64 encoded data URL')
                }

                const [mimeType, encoding] = dm.replace(/^data:/, '').split(';')
                if (encoding !== 'base64') {
                    throw new Error('Please provide image as base64 encoded data URL')
                }

                return {
                    inlineData: {
                        data,
                        mimeType
                    }
                }
            } else if (c.type === 'media') {
                return messageContentMedia(c)
            } else if (c.type === 'tool_use') {
                return {
                    functionCall: {
                        name: c.name,
                        args: c.input
                    }
                }
            }
            throw new Error(`Unknown content type ${(c as { type: string }).type}`)
        })
    }

    return [...messageParts, ...functionCalls, ...functionResponses]
}

// This is a dedicated logic for Multi Agent Supervisor to handle the case where the content is empty, and the role is the same
function checkIfEmptyContentAndSameRole(contents: Content[]) {
    let prevRole = ''
    const validContents: Content[] = []

    for (const content of contents) {
        // Skip only if completely empty
        if (!content.parts || !content.parts.length) {
            continue
        }

        // Ensure role is always either 'user' or 'model'
        content.role = content.role === 'model' ? 'model' : 'user'

        // Handle consecutive messages
        if (content.role === prevRole && validContents.length > 0) {
            // Merge with previous content if same role
            validContents[validContents.length - 1].parts.push(...content.parts)
            continue
        }

        validContents.push(content)
        prevRole = content.role
    }

    return validContents
}

function convertBaseMessagesToContent(messages: BaseMessage[], isMultimodalModel: boolean) {
    return messages.reduce<{
        content: Content[]
        mergeWithPreviousContent: boolean
    }>(
        (acc, message, index) => {
            if (!isBaseMessage(message)) {
                throw new Error('Unsupported message input')
            }
            const author = getMessageAuthor(message)
            if (author === 'system' && index !== 0) {
                throw new Error('System message should be the first one')
            }
            const role = convertAuthorToRole(author)

            const prevContent = acc.content[acc.content.length]
            if (!acc.mergeWithPreviousContent && prevContent && prevContent.role === role) {
                throw new Error('Google Generative AI requires alternate messages between authors')
            }

            const parts = convertMessageContentToParts(message, isMultimodalModel)

            if (acc.mergeWithPreviousContent) {
                const prevContent = acc.content[acc.content.length - 1]
                if (!prevContent) {
                    throw new Error('There was a problem parsing your system message. Please try a prompt without one.')
                }
                prevContent.parts.push(...parts)

                return {
                    mergeWithPreviousContent: false,
                    content: acc.content
                }
            }
            let actualRole = role
            if (actualRole === 'function' || actualRole === 'tool') {
                // GenerativeAI API will throw an error if the role is not "user" or "model."
                actualRole = 'user'
            }
            const content: Content = {
                role: actualRole,
                parts
            }
            return {
                mergeWithPreviousContent: author === 'system',
                content: [...acc.content, content]
            }
        },
        { content: [], mergeWithPreviousContent: false }
    ).content
}

function mapGenerateContentResultToChatResult(
    response: EnhancedGenerateContentResponse,
    extra?: {
        usageMetadata: UsageMetadata | undefined
    }
): ChatResult {
    // if rejected or error, return empty generations with reason in filters
    if (!response.candidates || response.candidates.length === 0 || !response.candidates[0]) {
        return {
            generations: [],
            llmOutput: {
                filters: response.promptFeedback
            }
        }
    }

    const functionCalls = response.functionCalls()
    const [candidate] = response.candidates
    const { content, ...generationInfo } = candidate
    const text = content?.parts[0]?.text ?? ''

    const generation: ChatGeneration = {
        text,
        message: new AIMessage({
            content: text,
            tool_calls: functionCalls,
            additional_kwargs: {
                ...generationInfo
            },
            usage_metadata: extra?.usageMetadata as any
        }),
        generationInfo
    }

    return {
        generations: [generation]
    }
}

function convertResponseContentToChatGenerationChunk(
    response: EnhancedGenerateContentResponse,
    extra: {
        usageMetadata?: UsageMetadata | undefined
        index: number
    }
): ChatGenerationChunk | null {
    if (!response || !response.candidates || response.candidates.length === 0) {
        return null
    }
    const functionCalls = response.functionCalls()
    const [candidate] = response.candidates
    const { content, ...generationInfo } = candidate
    const text = content?.parts?.[0]?.text ?? ''

    const toolCallChunks: ToolCallChunk[] = []
    if (functionCalls) {
        toolCallChunks.push(
            ...functionCalls.map((fc) => ({
                ...fc,
                args: JSON.stringify(fc.args),
                index: extra.index
            }))
        )
    }
    return new ChatGenerationChunk({
        text,
        message: new AIMessageChunk({
            content: text,
            name: !content ? undefined : content.role,
            tool_call_chunks: toolCallChunks,
            // Each chunk can have unique "generationInfo", and merging strategy is unclear,
            // so leave blank for now.
            additional_kwargs: {},
            usage_metadata: extra.usageMetadata as any
        }),
        generationInfo
    })
}

function zodToGeminiParameters(zodObj: any) {
    // Gemini doesn't accept either the $schema or additionalProperties
    // attributes, so we need to explicitly remove them.
    const jsonSchema: any = zodToJsonSchema(zodObj)
    // eslint-disable-next-line unused-imports/no-unused-vars
    const { $schema, additionalProperties, ...rest } = jsonSchema

    // Ensure all properties have type specified
    if (rest.properties) {
        Object.keys(rest.properties).forEach((key) => {
            const prop = rest.properties[key]

            // Handle enum types
            if (prop.enum?.length) {
                rest.properties[key] = {
                    type: 'string',
                    format: 'enum',
                    enum: prop.enum
                }
            }
            // Handle missing type
            else if (!prop.type && !prop.oneOf && !prop.anyOf && !prop.allOf) {
                // Infer type from other properties
                if (prop.minimum !== undefined || prop.maximum !== undefined) {
                    prop.type = 'number'
                } else if (prop.format === 'date-time') {
                    prop.type = 'string'
                } else if (prop.items) {
                    prop.type = 'array'
                } else if (prop.properties) {
                    prop.type = 'object'
                } else {
                    // Default to string if type can't be inferred
                    prop.type = 'string'
                }
            }
        })
    }

    return rest
}

function convertToGeminiTools(structuredTools: (StructuredToolInterface | Record<string, unknown>)[]) {
    return [
        {
            functionDeclarations: structuredTools.map((structuredTool) => {
                if (isStructuredTool(structuredTool)) {
                    const jsonSchema = zodToGeminiParameters(structuredTool.schema)
                    return {
                        name: structuredTool.name,
                        description: structuredTool.description,
                        parameters: jsonSchema
                    }
                }
                return structuredTool
            })
        }
    ]
}
*/
@@ -0,0 +1,630 @@
import {
    EnhancedGenerateContentResponse,
    Content,
    Part,
    type FunctionDeclarationsTool as GoogleGenerativeAIFunctionDeclarationsTool,
    type FunctionDeclaration as GenerativeAIFunctionDeclaration,
    POSSIBLE_ROLES,
    FunctionCallPart,
    TextPart,
    FileDataPart,
    InlineDataPart
} from '@google/generative-ai'
import {
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    ToolMessage,
    ToolMessageChunk,
    MessageContent,
    MessageContentComplex,
    UsageMetadata,
    isAIMessage,
    isBaseMessage,
    isToolMessage,
    StandardContentBlockConverter,
    parseBase64DataUrl,
    convertToProviderContentBlock,
    isDataContentBlock
} from '@langchain/core/messages'
import { ChatGeneration, ChatGenerationChunk, ChatResult } from '@langchain/core/outputs'
import { isLangChainTool } from '@langchain/core/utils/function_calling'
import { isOpenAITool } from '@langchain/core/language_models/base'
import { ToolCallChunk } from '@langchain/core/messages/tool'
import { v4 as uuidv4 } from 'uuid'
import { jsonSchemaToGeminiParameters, schemaToGenerativeAIParameters } from './zod_to_genai_parameters.js'
import { GoogleGenerativeAIToolType } from './types.js'

export function getMessageAuthor(message: BaseMessage) {
    const type = message._getType()
    if (ChatMessage.isInstance(message)) {
        return message.role
    }
    if (type === 'tool') {
        return type
    }
    return message.name ?? type
}

/**
 * Maps a message type to a Google Generative AI chat author.
 * @param message The message to map.
 * @param model The model to use for mapping.
 * @returns The message type mapped to a Google Generative AI chat author.
 */
export function convertAuthorToRole(author: string): (typeof POSSIBLE_ROLES)[number] {
    switch (author) {
        /**
         * Note: Gemini currently is not supporting system messages
         * we will convert them to human messages and merge with following
         * */
        case 'supervisor':
        case 'ai':
        case 'model': // getMessageAuthor returns message.name. code ex.: return message.name ?? type;
            return 'model'
        case 'system':
            return 'system'
        case 'human':
            return 'user'
        case 'tool':
        case 'function':
            return 'function'
        default:
            return 'user' // return user as default instead of throwing error
    }
}

function messageContentMedia(content: MessageContentComplex): Part {
    if ('mimeType' in content && 'data' in content) {
        return {
            inlineData: {
                mimeType: content.mimeType,
                data: content.data
            }
        }
    }
    if ('mimeType' in content && 'fileUri' in content) {
        return {
            fileData: {
                mimeType: content.mimeType,
                fileUri: content.fileUri
            }
        }
    }

    throw new Error('Invalid media content')
}

function inferToolNameFromPreviousMessages(message: ToolMessage | ToolMessageChunk, previousMessages: BaseMessage[]): string | undefined {
    return previousMessages
        .map((msg) => {
            if (isAIMessage(msg)) {
                return msg.tool_calls ?? []
            }
            return []
        })
        .flat()
        .find((toolCall) => {
            return toolCall.id === message.tool_call_id
        })?.name
}

function _getStandardContentBlockConverter(isMultimodalModel: boolean) {
    const standardContentBlockConverter: StandardContentBlockConverter<{
        text: TextPart
        image: FileDataPart | InlineDataPart
        audio: FileDataPart | InlineDataPart
        file: FileDataPart | InlineDataPart | TextPart
    }> = {
        providerName: 'Google Gemini',

        fromStandardTextBlock(block) {
            return {
                text: block.text
            }
        },

        fromStandardImageBlock(block): FileDataPart | InlineDataPart {
            if (!isMultimodalModel) {
                throw new Error('This model does not support images')
            }
            if (block.source_type === 'url') {
                const data = parseBase64DataUrl({ dataUrl: block.url })
                if (data) {
                    return {
                        inlineData: {
                            mimeType: data.mime_type,
                            data: data.data
                        }
                    }
                } else {
                    return {
                        fileData: {
                            mimeType: block.mime_type ?? '',
                            fileUri: block.url
                        }
                    }
                }
            }

            if (block.source_type === 'base64') {
                return {
                    inlineData: {
                        mimeType: block.mime_type ?? '',
                        data: block.data
                    }
                }
            }

            throw new Error(`Unsupported source type: ${block.source_type}`)
        },

        fromStandardAudioBlock(block): FileDataPart | InlineDataPart {
            if (!isMultimodalModel) {
                throw new Error('This model does not support audio')
            }
            if (block.source_type === 'url') {
                const data = parseBase64DataUrl({ dataUrl: block.url })
                if (data) {
                    return {
                        inlineData: {
                            mimeType: data.mime_type,
                            data: data.data
                        }
                    }
                } else {
                    return {
                        fileData: {
                            mimeType: block.mime_type ?? '',
                            fileUri: block.url
                        }
                    }
                }
            }

            if (block.source_type === 'base64') {
                return {
                    inlineData: {
                        mimeType: block.mime_type ?? '',
                        data: block.data
                    }
                }
            }

            throw new Error(`Unsupported source type: ${block.source_type}`)
        },

        fromStandardFileBlock(block): FileDataPart | InlineDataPart | TextPart {
            if (!isMultimodalModel) {
                throw new Error('This model does not support files')
            }
            if (block.source_type === 'text') {
                return {
                    text: block.text
                }
            }
            if (block.source_type === 'url') {
                const data = parseBase64DataUrl({ dataUrl: block.url })
                if (data) {
                    return {
                        inlineData: {
                            mimeType: data.mime_type,
                            data: data.data
                        }
                    }
                } else {
                    return {
                        fileData: {
                            mimeType: block.mime_type ?? '',
                            fileUri: block.url
                        }
                    }
                }
            }

            if (block.source_type === 'base64') {
                return {
                    inlineData: {
                        mimeType: block.mime_type ?? '',
                        data: block.data
                    }
                }
            }
            throw new Error(`Unsupported source type: ${block.source_type}`)
        }
    }
    return standardContentBlockConverter
}

function _convertLangChainContentToPart(content: MessageContentComplex, isMultimodalModel: boolean): Part | undefined {
    if (isDataContentBlock(content)) {
        return convertToProviderContentBlock(content, _getStandardContentBlockConverter(isMultimodalModel))
    }

    if (content.type === 'text') {
        return { text: content.text }
    } else if (content.type === 'executableCode') {
        return { executableCode: content.executableCode }
    } else if (content.type === 'codeExecutionResult') {
        return { codeExecutionResult: content.codeExecutionResult }
    } else if (content.type === 'image_url') {
        if (!isMultimodalModel) {
            throw new Error(`This model does not support images`)
        }
        let source
        if (typeof content.image_url === 'string') {
            source = content.image_url
        } else if (typeof content.image_url === 'object' && 'url' in content.image_url) {
            source = content.image_url.url
        } else {
            throw new Error('Please provide image as base64 encoded data URL')
        }
        const [dm, data] = source.split(',')
        if (!dm.startsWith('data:')) {
            throw new Error('Please provide image as base64 encoded data URL')
        }

        const [mimeType, encoding] = dm.replace(/^data:/, '').split(';')
        if (encoding !== 'base64') {
            throw new Error('Please provide image as base64 encoded data URL')
        }

        return {
            inlineData: {
                data,
                mimeType
            }
        }
    } else if (content.type === 'media') {
        return messageContentMedia(content)
    } else if (content.type === 'tool_use') {
        return {
            functionCall: {
                name: content.name,
                args: content.input
            }
        }
    } else if (
        content.type?.includes('/') &&
        // Ensure it's a single slash.
        content.type.split('/').length === 2 &&
        'data' in content &&
        typeof content.data === 'string'
    ) {
        return {
            inlineData: {
                mimeType: content.type,
                data: content.data
            }
        }
    } else if ('functionCall' in content) {
        // No action needed here - function calls will be added later from message.tool_calls
        return undefined
    } else {
        if ('type' in content) {
            throw new Error(`Unknown content type ${content.type}`)
        } else {
            throw new Error(`Unknown content ${JSON.stringify(content)}`)
        }
    }
}

export function convertMessageContentToParts(message: BaseMessage, isMultimodalModel: boolean, previousMessages: BaseMessage[]): Part[] {
    if (isToolMessage(message)) {
        const messageName = message.name ?? inferToolNameFromPreviousMessages(message, previousMessages)
        if (messageName === undefined) {
            throw new Error(
                `Google requires a tool name for each tool call response, and we could not infer a called tool name for ToolMessage "${message.id}" from your passed messages. Please populate a "name" field on that ToolMessage explicitly.`
            )
        }

        const result = Array.isArray(message.content)
            ? (message.content.map((c) => _convertLangChainContentToPart(c, isMultimodalModel)).filter((p) => p !== undefined) as Part[])
            : message.content

        if (message.status === 'error') {
            return [
                {
                    functionResponse: {
                        name: messageName,
                        // The API expects an object with an `error` field if the function call fails.
                        // `error` must be a valid object (not a string or array), so we wrap `message.content` here
                        response: { error: { details: result } }
                    }
                }
            ]
        }

        return [
            {
                functionResponse: {
                    name: messageName,
                    // again, can't have a string or array value for `response`, so we wrap it as an object here
                    response: { result }
                }
            }
        ]
    }

    let functionCalls: FunctionCallPart[] = []
    const messageParts: Part[] = []

    if (typeof message.content === 'string' && message.content) {
        messageParts.push({ text: message.content })
    }

    if (Array.isArray(message.content)) {
        messageParts.push(
            ...(message.content.map((c) => _convertLangChainContentToPart(c, isMultimodalModel)).filter((p) => p !== undefined) as Part[])
        )
    }

    if (isAIMessage(message) && message.tool_calls?.length) {
        functionCalls = message.tool_calls.map((tc) => {
            return {
                functionCall: {
                    name: tc.name,
                    args: tc.args
                }
            }
        })
    }

    return [...messageParts, ...functionCalls]
}

export function convertBaseMessagesToContent(
    messages: BaseMessage[],
    isMultimodalModel: boolean,
    convertSystemMessageToHumanContent: boolean = false
) {
    return messages.reduce<{
        content: Content[]
        mergeWithPreviousContent: boolean
    }>(
        (acc, message, index) => {
            if (!isBaseMessage(message)) {
                throw new Error('Unsupported message input')
            }
            const author = getMessageAuthor(message)
            if (author === 'system' && index !== 0) {
                throw new Error('System message should be the first one')
            }
            const role = convertAuthorToRole(author)

            const prevContent = acc.content[acc.content.length]
            if (!acc.mergeWithPreviousContent && prevContent && prevContent.role === role) {
                throw new Error('Google Generative AI requires alternate messages between authors')
            }

            const parts = convertMessageContentToParts(message, isMultimodalModel, messages.slice(0, index))

            if (acc.mergeWithPreviousContent) {
                const prevContent = acc.content[acc.content.length - 1]
                if (!prevContent) {
                    throw new Error('There was a problem parsing your system message. Please try a prompt without one.')
                }
                prevContent.parts.push(...parts)

                return {
                    mergeWithPreviousContent: false,
                    content: acc.content
                }
            }
            let actualRole = role
            if (actualRole === 'function' || (actualRole === 'system' && !convertSystemMessageToHumanContent)) {
                // GenerativeAI API will throw an error if the role is not "user" or "model."
                actualRole = 'user'
            }
            const content: Content = {
                role: actualRole,
                parts
            }
            return {
                mergeWithPreviousContent: author === 'system' && !convertSystemMessageToHumanContent,
                content: [...acc.content, content]
            }
        },
        { content: [], mergeWithPreviousContent: false }
    ).content
}

export function mapGenerateContentResultToChatResult(
    response: EnhancedGenerateContentResponse,
    extra?: {
        usageMetadata: UsageMetadata | undefined
    }
): ChatResult {
    // if rejected or error, return empty generations with reason in filters
    if (!response.candidates || response.candidates.length === 0 || !response.candidates[0]) {
        return {
            generations: [],
            llmOutput: {
                filters: response.promptFeedback
            }
        }
    }

    const functionCalls = response.functionCalls()
    const [candidate] = response.candidates
    const { content: candidateContent, ...generationInfo } = candidate
    let content: MessageContent | undefined

    if (Array.isArray(candidateContent?.parts) && candidateContent.parts.length === 1 && candidateContent.parts[0].text) {
        content = candidateContent.parts[0].text
    } else if (Array.isArray(candidateContent?.parts) && candidateContent.parts.length > 0) {
        content = candidateContent.parts.map((p) => {
            if ('text' in p) {
                return {
                    type: 'text',
                    text: p.text
                }
            } else if ('executableCode' in p) {
                return {
                    type: 'executableCode',
                    executableCode: p.executableCode
                }
            } else if ('codeExecutionResult' in p) {
                return {
                    type: 'codeExecutionResult',
                    codeExecutionResult: p.codeExecutionResult
                }
            }
            return p
        })
    } else {
        // no content returned - likely due to abnormal stop reason, e.g. malformed function call
        content = []
    }

    let text = ''
    if (typeof content === 'string') {
        text = content
    } else if (Array.isArray(content) && content.length > 0) {
        const block = content.find((b) => 'text' in b) as { text: string } | undefined
        text = block?.text ?? text
    }

    const generation: ChatGeneration = {
        text,
        message: new AIMessage({
            content: content ?? '',
            tool_calls: functionCalls?.map((fc) => {
                return {
                    ...fc,
                    type: 'tool_call',
                    id: 'id' in fc && typeof fc.id === 'string' ? fc.id : uuidv4()
                }
            }),
            additional_kwargs: {
                ...generationInfo
            },
            usage_metadata: extra?.usageMetadata
        }),
        generationInfo
    }

    return {
        generations: [generation],
        llmOutput: {
            tokenUsage: {
                promptTokens: extra?.usageMetadata?.input_tokens,
                completionTokens: extra?.usageMetadata?.output_tokens,
                totalTokens: extra?.usageMetadata?.total_tokens
            }
        }
    }
}

export function convertResponseContentToChatGenerationChunk(
    response: EnhancedGenerateContentResponse,
    extra: {
        usageMetadata?: UsageMetadata | undefined
        index: number
    }
): ChatGenerationChunk | null {
    if (!response.candidates || response.candidates.length === 0) {
        return null
    }
    const functionCalls = response.functionCalls()
    const [candidate] = response.candidates
    const { content: candidateContent, ...generationInfo } = candidate
    let content: MessageContent | undefined
    // Checks if some parts do not have text. If false, it means that the content is a string.
    if (Array.isArray(candidateContent?.parts) && candidateContent.parts.every((p) => 'text' in p)) {
        content = candidateContent.parts.map((p) => p.text).join('')
    } else if (Array.isArray(candidateContent?.parts)) {
        content = candidateContent.parts.map((p) => {
            if ('text' in p) {
                return {
                    type: 'text',
                    text: p.text
                }
            } else if ('executableCode' in p) {
                return {
                    type: 'executableCode',
                    executableCode: p.executableCode
                }
            } else if ('codeExecutionResult' in p) {
                return {
                    type: 'codeExecutionResult',
                    codeExecutionResult: p.codeExecutionResult
                }
            }
            return p
        })
    } else {
        // no content returned - likely due to abnormal stop reason, e.g. malformed function call
        content = []
    }

    let text = ''
    if (content && typeof content === 'string') {
        text = content
    } else if (Array.isArray(content)) {
        const block = content.find((b) => 'text' in b) as { text: string } | undefined
        text = block?.text ?? ''
    }

    const toolCallChunks: ToolCallChunk[] = []
    if (functionCalls) {
        toolCallChunks.push(
            ...functionCalls.map((fc) => ({
                ...fc,
                args: JSON.stringify(fc.args),
                index: extra.index,
                type: 'tool_call_chunk' as const,
                id: 'id' in fc && typeof fc.id === 'string' ? fc.id : uuidv4()
            }))
        )
    }

    return new ChatGenerationChunk({
        text,
        message: new AIMessageChunk({
            content: content || '',
            name: !candidateContent ? undefined : candidateContent.role,
            tool_call_chunks: toolCallChunks,
            // Each chunk can have unique "generationInfo", and merging strategy is unclear,
            // so leave blank for now.
            additional_kwargs: {},
            usage_metadata: extra.usageMetadata
        }),
        generationInfo
    })
}

export function convertToGenerativeAITools(tools: GoogleGenerativeAIToolType[]): GoogleGenerativeAIFunctionDeclarationsTool[] {
    if (tools.every((tool) => 'functionDeclarations' in tool && Array.isArray(tool.functionDeclarations))) {
        return tools as GoogleGenerativeAIFunctionDeclarationsTool[]
    }
    return [
        {
            functionDeclarations: tools.map((tool): GenerativeAIFunctionDeclaration => {
                if (isLangChainTool(tool)) {
                    const jsonSchema = schemaToGenerativeAIParameters(tool.schema)
                    if (jsonSchema.type === 'object' && 'properties' in jsonSchema && Object.keys(jsonSchema.properties).length === 0) {
                        return {
                            name: tool.name,
                            description: tool.description
                        }
                    }
                    return {
                        name: tool.name,
                        description: tool.description,
                        parameters: jsonSchema
                    }
                }
                if (isOpenAITool(tool)) {
                    return {
                        name: tool.function.name,
                        description: tool.function.description ?? `A function available to call.`,
                        parameters: jsonSchemaToGeminiParameters(tool.function.parameters)
                    }
                }
                return tool as unknown as GenerativeAIFunctionDeclaration
            })
        }
    ]
}
@@ -0,0 +1,63 @@
import { BaseLLMOutputParser, OutputParserException } from '@langchain/core/output_parsers'
import { ChatGeneration } from '@langchain/core/outputs'
import { ToolCall } from '@langchain/core/messages/tool'
import { InteropZodType, interopSafeParseAsync } from '@langchain/core/utils/types'
import { JsonOutputKeyToolsParserParamsInterop } from '@langchain/core/output_parsers/openai_tools'

interface GoogleGenerativeAIToolsOutputParserParams<T extends Record<string, any>> extends JsonOutputKeyToolsParserParamsInterop<T> {}

export class GoogleGenerativeAIToolsOutputParser<T extends Record<string, any> = Record<string, any>> extends BaseLLMOutputParser<T> {
    static lc_name() {
        return 'GoogleGenerativeAIToolsOutputParser'
    }

    lc_namespace = ['langchain', 'google_genai', 'output_parsers']

    returnId = false

    /** The type of tool calls to return. */
    keyName: string

    /** Whether to return only the first tool call. */
    returnSingle = false

    zodSchema?: InteropZodType<T>

    constructor(params: GoogleGenerativeAIToolsOutputParserParams<T>) {
        super(params)
        this.keyName = params.keyName
        this.returnSingle = params.returnSingle ?? this.returnSingle
        this.zodSchema = params.zodSchema
    }

    protected async _validateResult(result: unknown): Promise<T> {
        if (this.zodSchema === undefined) {
            return result as T
        }
        const zodParsedResult = await interopSafeParseAsync(this.zodSchema, result)
        if (zodParsedResult.success) {
            return zodParsedResult.data
        } else {
            throw new OutputParserException(
                `Failed to parse. Text: "${JSON.stringify(result, null, 2)}". Error: ${JSON.stringify(zodParsedResult.error.issues)}`,
                JSON.stringify(result, null, 2)
            )
        }
    }

    async parseResult(generations: ChatGeneration[]): Promise<T> {
        const tools = generations.flatMap((generation) => {
            const { message } = generation
            if (!('tool_calls' in message) || !Array.isArray(message.tool_calls)) {
                return []
            }
            return message.tool_calls as ToolCall[]
        })
        if (tools[0] === undefined) {
            throw new Error('No parseable tool calls provided to GoogleGenerativeAIToolsOutputParser.')
        }
        const [tool] = tools
        const validatedResult = await this._validateResult(tool.args)
        return validatedResult
    }
}
@@ -0,0 +1,136 @@
import {
    Tool as GenerativeAITool,
    ToolConfig,
    FunctionCallingMode,
    FunctionDeclaration,
    FunctionDeclarationsTool,
    FunctionDeclarationSchema
} from '@google/generative-ai'
import { ToolChoice } from '@langchain/core/language_models/chat_models'
import { StructuredToolInterface } from '@langchain/core/tools'
import { isLangChainTool } from '@langchain/core/utils/function_calling'
import { isOpenAITool, ToolDefinition } from '@langchain/core/language_models/base'
import { convertToGenerativeAITools } from './common.js'
import { GoogleGenerativeAIToolType } from './types.js'
import { removeAdditionalProperties } from './zod_to_genai_parameters.js'

export function convertToolsToGenAI(
    tools: GoogleGenerativeAIToolType[],
    extra?: {
        toolChoice?: ToolChoice
        allowedFunctionNames?: string[]
    }
): {
    tools: GenerativeAITool[]
    toolConfig?: ToolConfig
} {
    // Extract function declaration processing to a separate function
    const genAITools = processTools(tools)

    // Simplify tool config creation
    const toolConfig = createToolConfig(genAITools, extra)

    return { tools: genAITools, toolConfig }
}

function processTools(tools: GoogleGenerativeAIToolType[]): GenerativeAITool[] {
    let functionDeclarationTools: FunctionDeclaration[] = []
    const genAITools: GenerativeAITool[] = []

    tools.forEach((tool) => {
        if (isLangChainTool(tool)) {
            const [convertedTool] = convertToGenerativeAITools([tool as StructuredToolInterface])
            if (convertedTool.functionDeclarations) {
                functionDeclarationTools.push(...convertedTool.functionDeclarations)
            }
        } else if (isOpenAITool(tool)) {
            const { functionDeclarations } = convertOpenAIToolToGenAI(tool)
            if (functionDeclarations) {
                functionDeclarationTools.push(...functionDeclarations)
            } else {
                throw new Error('Failed to convert OpenAI structured tool to GenerativeAI tool')
            }
        } else {
            genAITools.push(tool as GenerativeAITool)
        }
    })

    const genAIFunctionDeclaration = genAITools.find((t) => 'functionDeclarations' in t)
    if (genAIFunctionDeclaration) {
        return genAITools.map((tool) => {
            if (functionDeclarationTools?.length > 0 && 'functionDeclarations' in tool) {
                const newTool = {
                    functionDeclarations: [...(tool.functionDeclarations || []), ...functionDeclarationTools]
                }
                // Clear the functionDeclarationTools array so it is not passed again
                functionDeclarationTools = []
                return newTool
            }
            return tool
        })
    }

    return [
        ...genAITools,
        ...(functionDeclarationTools.length > 0
            ? [
                  {
                      functionDeclarations: functionDeclarationTools
                  }
              ]
            : [])
    ]
}

function convertOpenAIToolToGenAI(tool: ToolDefinition): FunctionDeclarationsTool {
    return {
        functionDeclarations: [
            {
                name: tool.function.name,
                description: tool.function.description,
                parameters: removeAdditionalProperties(tool.function.parameters) as FunctionDeclarationSchema
            }
        ]
    }
}

function createToolConfig(
    genAITools: GenerativeAITool[],
    extra?: {
        toolChoice?: ToolChoice
        allowedFunctionNames?: string[]
    }
): ToolConfig | undefined {
    if (!genAITools.length || !extra) return undefined

    const { toolChoice, allowedFunctionNames } = extra

    const modeMap: Record<string, FunctionCallingMode> = {
        any: FunctionCallingMode.ANY,
        auto: FunctionCallingMode.AUTO,
        none: FunctionCallingMode.NONE
    }

    if (toolChoice && ['any', 'auto', 'none'].includes(toolChoice as string)) {
        return {
            functionCallingConfig: {
                mode: modeMap[toolChoice as keyof typeof modeMap] ?? 'MODE_UNSPECIFIED',
                allowedFunctionNames
            }
        }
    }

    if (typeof toolChoice === 'string' || allowedFunctionNames) {
        return {
            functionCallingConfig: {
                mode: FunctionCallingMode.ANY,
                allowedFunctionNames: [
                    ...(allowedFunctionNames ?? []),
                    ...(toolChoice && typeof toolChoice === 'string' ? [toolChoice] : [])
                ]
            }
        }
    }

    return undefined
}
@@ -0,0 +1,12 @@
import {
    CodeExecutionTool,
    FunctionDeclarationsTool as GoogleGenerativeAIFunctionDeclarationsTool,
    GoogleSearchRetrievalTool
} from '@google/generative-ai'
import { BindToolsInput } from '@langchain/core/language_models/chat_models'

export type GoogleGenerativeAIToolType =
    | BindToolsInput
    | GoogleGenerativeAIFunctionDeclarationsTool
    | CodeExecutionTool
    | GoogleSearchRetrievalTool
@@ -0,0 +1,67 @@
import {
    type FunctionDeclarationSchema as GenerativeAIFunctionDeclarationSchema,
    type SchemaType as FunctionDeclarationSchemaType
} from '@google/generative-ai'
import { InteropZodType, isInteropZodSchema } from '@langchain/core/utils/types'
import { type JsonSchema7Type, toJsonSchema } from '@langchain/core/utils/json_schema'

export interface GenerativeAIJsonSchema extends Record<string, unknown> {
    properties?: Record<string, GenerativeAIJsonSchema>
    type: FunctionDeclarationSchemaType
}

export interface GenerativeAIJsonSchemaDirty extends GenerativeAIJsonSchema {
    properties?: Record<string, GenerativeAIJsonSchemaDirty>
    additionalProperties?: boolean
}

export function removeAdditionalProperties(obj: Record<string, any>): GenerativeAIJsonSchema {
    if (typeof obj === 'object' && obj !== null) {
        const newObj = { ...obj }

        if ('additionalProperties' in newObj) {
            delete newObj.additionalProperties
        }
        if ('$schema' in newObj) {
            delete newObj.$schema
        }
        if ('strict' in newObj) {
            delete newObj.strict
        }

        for (const key in newObj) {
            if (key in newObj) {
                if (Array.isArray(newObj[key])) {
                    newObj[key] = newObj[key].map(removeAdditionalProperties)
                } else if (typeof newObj[key] === 'object' && newObj[key] !== null) {
                    newObj[key] = removeAdditionalProperties(newObj[key])
                }
            }
        }

        return newObj as GenerativeAIJsonSchema
    }

    return obj as GenerativeAIJsonSchema
}

export function schemaToGenerativeAIParameters<RunOutput extends Record<string, any> = Record<string, any>>(
    schema: InteropZodType<RunOutput> | JsonSchema7Type
): GenerativeAIFunctionDeclarationSchema {
    // GenerativeAI doesn't accept either the $schema or additionalProperties
    // attributes, so we need to explicitly remove them.
    const jsonSchema = removeAdditionalProperties(isInteropZodSchema(schema) ? toJsonSchema(schema) : schema)
    const { _schema, ...rest } = jsonSchema

    return rest as GenerativeAIFunctionDeclarationSchema
}

export function jsonSchemaToGeminiParameters(schema: Record<string, any>): GenerativeAIFunctionDeclarationSchema {
    // Gemini doesn't accept either the $schema or additionalProperties
    // attributes, so we need to explicitly remove them.

    const jsonSchema = removeAdditionalProperties(schema as GenerativeAIJsonSchemaDirty)
    const { _schema, ...rest } = jsonSchema

    return rest as GenerativeAIFunctionDeclarationSchema
}
@@ -132,7 +132,7 @@ export async function MCPTool({
     const client = await toolkit.createClient()

     try {
-        const req: CallToolRequest = { method: 'tools/call', params: { name: name, arguments: input } }
+        const req: CallToolRequest = { method: 'tools/call', params: { name: name, arguments: input as any } }
         const res = await client.request(req, CallToolResultSchema)
         const content = res.content
         const contentString = JSON.stringify(content)
@@ -42,7 +42,7 @@
         "@langchain/baidu-qianfan": "^0.1.0",
         "@langchain/cohere": "^0.0.7",
         "@langchain/community": "^0.3.29",
-        "@langchain/core": "0.3.37",
+        "@langchain/core": "0.3.61",
         "@langchain/exa": "^0.0.5",
         "@langchain/google-genai": "0.2.3",
         "@langchain/google-vertexai": "^0.2.0",
@@ -308,7 +308,7 @@ const _generateSelectedTools = async (config: Record<string, any>, question: str
     const model = (await newToolNodeInstance.init(config.selectedChatModel, '', options)) as BaseChatModel

     // Create a parser to validate the output
-    const parser = StructuredOutputParser.fromZodSchema(ToolType)
+    const parser = StructuredOutputParser.fromZodSchema(ToolType as any)

     // Generate JSON schema from our Zod schema
     const formatInstructions = parser.getFormatInstructions()
@@ -364,7 +364,7 @@ const generateNodesEdges = async (config: Record<string, any>, question: string,
     const model = (await newToolNodeInstance.init(config.selectedChatModel, '', options)) as BaseChatModel

     // Create a parser to validate the output
-    const parser = StructuredOutputParser.fromZodSchema(NodesEdgesType)
+    const parser = StructuredOutputParser.fromZodSchema(NodesEdgesType as any)

     // Generate JSON schema from our Zod schema
     const formatInstructions = parser.getFormatInstructions()
@@ -56,7 +56,7 @@ export const generateFollowUpPrompts = async (
         temperature: parseFloat(`${providerConfig.temperature}`)
     })
     // use structured output parser because withStructuredOutput is not working
-    const parser = StructuredOutputParser.fromZodSchema(FollowUpPromptType)
+    const parser = StructuredOutputParser.fromZodSchema(FollowUpPromptType as any)
     const formatInstructions = parser.getFormatInstructions()
     const prompt = PromptTemplate.fromTemplate(`
     ${providerConfig.prompt}
Binary file not shown.
After: Size 112 KiB
@@ -15,6 +15,7 @@ import azureOpenAiIcon from '@/assets/images/azure_openai.svg'
 import mistralAiIcon from '@/assets/images/mistralai.svg'
 import openAiIcon from '@/assets/images/openai.svg'
 import groqIcon from '@/assets/images/groq.png'
+import geminiIcon from '@/assets/images/gemini.png'
 import ollamaIcon from '@/assets/images/ollama.svg'
 import { TooltipWithParser } from '@/ui-component/tooltip/TooltipWithParser'
 import CredentialInputHandler from '@/views/canvas/CredentialInputHandler'
@@ -117,7 +118,7 @@ const followUpPromptsOptions = {
     [FollowUpPromptProviders.GOOGLE_GENAI]: {
         label: 'Google Gemini',
         name: FollowUpPromptProviders.GOOGLE_GENAI,
-        icon: azureOpenAiIcon,
+        icon: geminiIcon,
         inputs: [
             {
                 label: 'Connect Credential',
@@ -128,12 +129,8 @@ const followUpPromptsOptions = {
             {
                 label: 'Model Name',
                 name: 'modelName',
-                type: 'options',
-                default: 'gemini-1.5-pro-latest',
-                options: [
-                    { label: 'gemini-1.5-flash-latest', name: 'gemini-1.5-flash-latest' },
-                    { label: 'gemini-1.5-pro-latest', name: 'gemini-1.5-pro-latest' }
-                ]
+                type: 'asyncOptions',
+                loadMethod: 'listModels'
             },
             {
                 label: 'Prompt',
@@ -204,11 +201,8 @@ const followUpPromptsOptions = {
             {
                 label: 'Model Name',
                 name: 'modelName',
-                type: 'options',
-                options: [
-                    { label: 'mistral-large-latest', name: 'mistral-large-latest' },
-                    { label: 'mistral-large-2402', name: 'mistral-large-2402' }
-                ]
+                type: 'asyncOptions',
+                loadMethod: 'listModels'
            },
            {
                label: 'Prompt',
272 pnpm-lock.yaml
File diff suppressed because one or more lines are too long