diff --git a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts index 546fa224c..9a4b8891b 100644 --- a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts +++ b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts @@ -1,7 +1,9 @@ import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' -import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' +import { convertMultiOptionsToStringArray, getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' import { BaseCache } from 'langchain/schema' -import { ChatGoogleGenerativeAI } from '@langchain/google-genai' +import { ChatGoogleGenerativeAI, GoogleGenerativeAIChatInput } from '@langchain/google-genai' +import { HarmBlockThreshold, HarmCategory } from '@google/generative-ai' +import type { SafetySetting } from '@google/generative-ai' class GoogleGenerativeAI_ChatModels implements INode { label: string @@ -74,6 +76,73 @@ class GoogleGenerativeAI_ChatModels implements INode { step: 0.1, optional: true, additionalParams: true + }, + { + label: 'Top Next Highest Probability Tokens', + name: 'topK', + type: 'number', + description: `Decode using top-k sampling: consider the set of top_k most probable tokens. Must be positive`, + step: 1, + optional: true, + additionalParams: true + }, + { + label: 'Harm Category', + name: 'harmCategory', + type: 'multiOptions', + description: + 'Refer to official guide on how to use Harm Category', + options: [ + { + label: 'Dangerous', + name: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + }, + { + label: 'Harassment', + name: HarmCategory.HARM_CATEGORY_HARASSMENT + }, + { + label: 'Hate Speech', + name: HarmCategory.HARM_CATEGORY_HATE_SPEECH + }, + { + label: 'Sexually Explicit', + name: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT + } + ], + optional: true, + additionalParams: true + }, + { + label: 'Harm Block Threshold', + name: 'harmBlockThreshold', + type: 'multiOptions', + description: + 'Refer to official guide on how to use Harm Block Threshold', + options: [ + { + label: 'Low and Above', + name: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + }, + { + label: 'Medium and Above', + name: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE + }, + { + label: 'None', + name: HarmBlockThreshold.BLOCK_NONE + }, + { + label: 'Only High', + name: HarmBlockThreshold.BLOCK_ONLY_HIGH + }, + { + label: 'Threshold Unspecified', + name: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED + } + ], + optional: true, + additionalParams: true } ] } @@ -86,9 +155,12 @@ class GoogleGenerativeAI_ChatModels implements INode { const modelName = nodeData.inputs?.modelName as string const maxOutputTokens = nodeData.inputs?.maxOutputTokens as string const topP = nodeData.inputs?.topP as string + const topK = nodeData.inputs?.topK as string + const harmCategory = nodeData.inputs?.harmCategory as string + const harmBlockThreshold = nodeData.inputs?.harmBlockThreshold as string const cache = nodeData.inputs?.cache as BaseCache - const obj = { + const obj: Partial = { apiKey: apiKey, modelName: modelName, maxOutputTokens: 2048 @@ -98,8 +170,23 @@ class GoogleGenerativeAI_ChatModels implements INode { const model = new ChatGoogleGenerativeAI(obj) if (topP) model.topP = parseFloat(topP) + if (topK) model.topK = parseFloat(topK) if (cache) model.cache = cache if (temperature) model.temperature = parseFloat(temperature) + + // Safety Settings + let harmCategories: string[] = convertMultiOptionsToStringArray(harmCategory) + let harmBlockThresholds: string[] = convertMultiOptionsToStringArray(harmBlockThreshold) + if (harmCategories.length != harmBlockThresholds.length) + throw new Error(`Harm Category & Harm Block Threshold are not the same length`) + const safetySettings: SafetySetting[] = harmCategories.map((harmCategory, index) => { + return { + category: harmCategory as HarmCategory, + threshold: harmBlockThresholds[index] as HarmBlockThreshold + } + }) + if (safetySettings.length > 0) model.safetySettings = safetySettings + return model } } diff --git a/packages/components/package.json b/packages/components/package.json index ddf093998..894014d42 100644 --- a/packages/components/package.json +++ b/packages/components/package.json @@ -26,6 +26,7 @@ "@gomomento/sdk": "^1.51.1", "@gomomento/sdk-core": "^1.51.1", "@google-ai/generativelanguage": "^0.2.1", + "@google/generative-ai": "^0.1.3", "@huggingface/inference": "^2.6.1", "@langchain/community": "^0.0.16", "@langchain/google-genai": "^0.0.6", diff --git a/packages/components/src/utils.ts b/packages/components/src/utils.ts index 22fa6f4a9..2215eb418 100644 --- a/packages/components/src/utils.ts +++ b/packages/components/src/utils.ts @@ -673,3 +673,18 @@ export const convertBaseMessagetoIMessage = (messages: BaseMessage[]): IMessage[ } return formatmessages } + +/** + * Convert MultiOptions String to String Array + * @param {string} inputString + * @returns {string[]} + */ +export const convertMultiOptionsToStringArray = (inputString: string): string[] => { + let ArrayString: string[] = [] + try { + ArrayString = JSON.parse(inputString) + } catch (e) { + ArrayString = [] + } + return ArrayString +}