diff --git a/packages/components/nodes/cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager.ts b/packages/components/nodes/cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager.ts new file mode 100644 index 000000000..dee516064 --- /dev/null +++ b/packages/components/nodes/cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager.ts @@ -0,0 +1,44 @@ +import type { CachedContentBase, CachedContent, Content } from '@google/generative-ai' +import { GoogleAICacheManager as GoogleAICacheManagerBase } from '@google/generative-ai/server' +import hash from 'object-hash' + +type CacheContentOptions = Omit<CachedContentBase, 'contents'> & { contents?: Content[] } + +export class GoogleAICacheManager extends GoogleAICacheManagerBase { + private ttlSeconds: number + private cachedContents: Map<string, CachedContent> = new Map() + + setTtlSeconds(ttlSeconds: number) { + this.ttlSeconds = ttlSeconds + } + + async lookup(options: CacheContentOptions): Promise<CachedContent | undefined> { + const { model, tools, contents } = options + if (!contents?.length) { + return undefined + } + const hashKey = hash({ + model, + tools, + contents + }) + if (this.cachedContents.has(hashKey)) { + return this.cachedContents.get(hashKey) + } + const { cachedContents } = await this.list() + const cachedContent = (cachedContents ?? 
[]).find((cache) => cache.displayName === hashKey) + if (cachedContent) { + this.cachedContents.set(hashKey, cachedContent) + return cachedContent + } + const res = await this.create({ + ...(options as CachedContentBase), + displayName: hashKey, + ttlSeconds: this.ttlSeconds + }) + this.cachedContents.set(hashKey, res) + return res + } +} + +export default GoogleAICacheManager diff --git a/packages/components/nodes/cache/GoogleGenerativeAIContextCache/GoogleGemini.svg b/packages/components/nodes/cache/GoogleGenerativeAIContextCache/GoogleGemini.svg new file mode 100644 index 000000000..53b497fa1 --- /dev/null +++ b/packages/components/nodes/cache/GoogleGenerativeAIContextCache/GoogleGemini.svg @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/packages/components/nodes/cache/GoogleGenerativeAIContextCache/GoogleGenerativeAIContextCache.ts b/packages/components/nodes/cache/GoogleGenerativeAIContextCache/GoogleGenerativeAIContextCache.ts new file mode 100644 index 000000000..9e6d28317 --- /dev/null +++ b/packages/components/nodes/cache/GoogleGenerativeAIContextCache/GoogleGenerativeAIContextCache.ts @@ -0,0 +1,53 @@ +import { getBaseClasses, getCredentialData, getCredentialParam, ICommonObject, INode, INodeData, INodeParams } from '../../../src' +import FlowiseGoogleAICacheManager from './FlowiseGoogleAICacheManager' + +class GoogleGenerativeAIContextCache implements INode { + label: string + name: string + version: number + description: string + type: string + icon: string + category: string + baseClasses: string[] + inputs: INodeParams[] + credential: INodeParams + + constructor() { + this.label = 'Google GenAI Context Cache' + this.name = 'googleGenerativeAIContextCache' + this.version = 1.0 + this.type = 'GoogleAICacheManager' + this.description = 'Large context cache for Google Gemini large language models' + this.icon = 'GoogleGemini.svg' + this.category = 'Cache' + this.baseClasses = [this.type, 
...getBaseClasses(FlowiseGoogleAICacheManager)] + this.inputs = [ + { + label: 'TTL', + name: 'ttl', + type: 'number', + default: 60 * 60 * 24 * 30 + } + ] + this.credential = { + label: 'Connect Credential', + name: 'credential', + type: 'credential', + credentialNames: ['googleGenerativeAI'], + optional: false, + description: 'Google Generative AI credential.' + } + } + + async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> { + const ttl = nodeData.inputs?.ttl as number + const credentialData = await getCredentialData(nodeData.credential ?? '', options) + const apiKey = getCredentialParam('googleGenerativeAPIKey', credentialData, nodeData) + const manager = new FlowiseGoogleAICacheManager(apiKey) + manager.setTtlSeconds(ttl) + return manager + } +} + +module.exports = { nodeClass: GoogleGenerativeAIContextCache } diff --git a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts index 3b13ab271..836f97039 100644 --- a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts +++ b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts @@ -5,6 +5,7 @@ import { ICommonObject, IMultiModalOption, INode, INodeData, INodeOptionsValue, import { convertMultiOptionsToStringArray, getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' import { getModels, MODEL_TYPE } from '../../../src/modelLoader' import { ChatGoogleGenerativeAI, GoogleGenerativeAIChatInput } from './FlowiseChatGoogleGenerativeAI' +import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager' class GoogleGenerativeAI_ChatModels implements INode { label: string @@ -42,6 +43,12 @@ class GoogleGenerativeAI_ChatModels implements INode { type: 'BaseCache', optional: true }, + { + label: 'Context Cache', + name: 'contextCache', + 
type: 'GoogleAICacheManager', + optional: true + }, { label: 'Model Name', name: 'modelName', @@ -188,6 +195,7 @@ class GoogleGenerativeAI_ChatModels implements INode { const harmCategory = nodeData.inputs?.harmCategory as string const harmBlockThreshold = nodeData.inputs?.harmBlockThreshold as string const cache = nodeData.inputs?.cache as BaseCache + const contextCache = nodeData.inputs?.contextCache as FlowiseGoogleAICacheManager const streaming = nodeData.inputs?.streaming as boolean const allowImageUploads = nodeData.inputs?.allowImageUploads as boolean @@ -225,6 +233,7 @@ class GoogleGenerativeAI_ChatModels implements INode { const model = new ChatGoogleGenerativeAI(nodeData.id, obj) model.setMultiModalOption(multiModalOption) + if (contextCache) model.setContextCache(contextCache) return model } diff --git a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI.ts b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI.ts index f8bb71110..51dcbcd91 100644 --- a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI.ts +++ b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI.ts @@ -25,6 +25,7 @@ import { StructuredToolInterface } from '@langchain/core/tools' import { isStructuredTool } from '@langchain/core/utils/function_calling' import { zodToJsonSchema } from 'zod-to-json-schema' import { BaseLanguageModelCallOptions } from '@langchain/core/language_models/base' +import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager' const DEFAULT_IMAGE_MAX_TOKEN = 8192 const DEFAULT_IMAGE_MODEL = 'gemini-1.5-flash-latest' @@ -86,6 +87,8 @@ class LangchainChatGoogleGenerativeAI private client: GenerativeModel + private contextCache?: FlowiseGoogleAICacheManager + get _isMultimodalModel() { return this.modelName.includes('vision') || 
this.modelName.startsWith('gemini-1.5') } @@ -147,7 +150,7 @@ class LangchainChatGoogleGenerativeAI this.getClient() } - getClient(tools?: Tool[]) { + async getClient(prompt?: Content[], tools?: Tool[]) { this.client = new GenerativeAI(this.apiKey ?? '').getGenerativeModel({ model: this.modelName, tools, @@ -161,6 +164,14 @@ class LangchainChatGoogleGenerativeAI topK: this.topK } }) + if (this.contextCache) { + const cachedContent = await this.contextCache.lookup({ + contents: prompt ? [{ ...prompt[0], parts: prompt[0].parts.slice(0, 1) }] : [], + model: this.modelName, + tools + }) + this.client.cachedContent = cachedContent as any + } } _combineLLMOutput() { @@ -209,6 +220,10 @@ class LangchainChatGoogleGenerativeAI } } + setContextCache(contextCache: FlowiseGoogleAICacheManager): void { + this.contextCache = contextCache + } + async getNumTokens(prompt: BaseMessage[]) { const contents = convertBaseMessagesToContent(prompt, this._isMultimodalModel) const { totalTokens } = await this.client.countTokens({ contents }) @@ -226,9 +241,9 @@ class LangchainChatGoogleGenerativeAI this.convertFunctionResponse(prompt) if (tools.length > 0) { - this.getClient(tools as Tool[]) + await this.getClient(prompt, tools as Tool[]) } else { - this.getClient() + await this.getClient(prompt) } const res = await this.caller.callWithOptions({ signal: options?.signal }, async () => { let output @@ -296,9 +311,9 @@ class LangchainChatGoogleGenerativeAI const tools = options.tools ?? [] if (tools.length > 0) { - this.getClient(tools as Tool[]) + await this.getClient(prompt, tools as Tool[]) } else { - this.getClient() + await this.getClient(prompt) } const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => {