diff --git a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts index 546fa224c..311b4e8db 100644 --- a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts +++ b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts @@ -1,7 +1,8 @@ import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' -import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' +import { convertStringToArrayString, getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' import { BaseCache } from 'langchain/schema' import { ChatGoogleGenerativeAI } from '@langchain/google-genai' +import { HarmBlockThreshold, HarmCategory } from '@google/generative-ai' class GoogleGenerativeAI_ChatModels implements INode { label: string @@ -74,6 +75,72 @@ class GoogleGenerativeAI_ChatModels implements INode { step: 0.1, optional: true, additionalParams: true + }, + { + label: 'topK', + name: 'topK', + type: 'number', + step: 0.1, + optional: true, + additionalParams: true + }, + { + label: 'Harm Category', + name: 'harmCategory', + type: 'multiOptions', + description: + 'Refer to official guide on how to use Harm Category', + options: [ + { + label: 'Dangerous', + name: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + }, + { + label: 'Harassment', + name: HarmCategory.HARM_CATEGORY_HARASSMENT + }, + { + label: 'Hate Speech', + name: HarmCategory.HARM_CATEGORY_HATE_SPEECH + }, + { + label: 'Sexually Explicit', + name: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT + } + ], + optional: true, + additionalParams: true + }, + { + label: 'Harm Block Threshold', + name: 'harmBlockThreshold', + type: 'multiOptions', + description: + 'Refer to official guide on how to use Harm Block Threshold', + options: [ + { + label: 'Low and Above', + name: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + }, + { + label: 'Medium and Above', + name: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE + }, + { + label: 'None', + name: HarmBlockThreshold.BLOCK_NONE + }, + { + label: 'Only High', + name: HarmBlockThreshold.BLOCK_ONLY_HIGH + }, + { + label: 'Threshold Unspecified', + name: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED + } + ], + optional: true, + additionalParams: true } ] } @@ -86,6 +153,9 @@ class GoogleGenerativeAI_ChatModels implements INode { const modelName = nodeData.inputs?.modelName as string const maxOutputTokens = nodeData.inputs?.maxOutputTokens as string const topP = nodeData.inputs?.topP as string + const topK = nodeData.inputs?.topK as string + const harmCategory = nodeData.inputs?.harmCategory as string + const harmBlockThreshold = nodeData.inputs?.harmBlockThreshold as string const cache = nodeData.inputs?.cache as BaseCache const obj = { @@ -98,8 +168,23 @@ class GoogleGenerativeAI_ChatModels implements INode { const model = new ChatGoogleGenerativeAI(obj) if (topP) model.topP = parseFloat(topP) + if (topK) model.topP = parseFloat(topK) if (cache) model.cache = cache if (temperature) model.temperature = parseFloat(temperature) + + // safetySettings + let harmCategories: string[] = convertStringToArrayString(harmCategory) + let harmBlockThresholds: string[] = convertStringToArrayString(harmBlockThreshold) + if (harmCategories.length != harmBlockThresholds.length) + throw new Error(`Harm Category & Harm Block Threshold are not the same length`) + const safetySettings = harmCategories.map((value, index) => { + return { + category: value, + threshold: harmBlockThresholds[index] + } + }) + if (safetySettings) model.safetySettings = safetySettings + return model } } diff --git a/packages/components/src/utils.ts b/packages/components/src/utils.ts index 22fa6f4a9..88c553cfc 100644 --- a/packages/components/src/utils.ts +++ b/packages/components/src/utils.ts @@ -673,3 +673,20 @@ export const convertBaseMessagetoIMessage = (messages: BaseMessage[]): IMessage[ } return formatmessages } + +/** + * Convert String to Array String + * @param {string} inputString + * @returns {string[]} + */ +export const convertStringToArrayString = (inputString: string): string[] => { + let ArrayString: string[] = [] + if (inputString) { + try { + ArrayString = JSON.parse(inputString) + } catch (e) { + ArrayString = [] + } + } + return ArrayString +}