diff --git a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts
index 546fa224c..9a4b8891b 100644
--- a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts
+++ b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts
@@ -1,7 +1,9 @@
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
-import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
+import { convertMultiOptionsToStringArray, getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
import { BaseCache } from 'langchain/schema'
-import { ChatGoogleGenerativeAI } from '@langchain/google-genai'
+import { ChatGoogleGenerativeAI, GoogleGenerativeAIChatInput } from '@langchain/google-genai'
+import { HarmBlockThreshold, HarmCategory } from '@google/generative-ai'
+import type { SafetySetting } from '@google/generative-ai'
class GoogleGenerativeAI_ChatModels implements INode {
label: string
@@ -74,6 +76,73 @@ class GoogleGenerativeAI_ChatModels implements INode {
step: 0.1,
optional: true,
additionalParams: true
+ },
+ {
+ label: 'Top Next Highest Probability Tokens',
+ name: 'topK',
+ type: 'number',
+ description: `Decode using top-k sampling: consider the set of top_k most probable tokens. Must be positive`,
+ step: 1,
+ optional: true,
+ additionalParams: true
+ },
+ {
+ label: 'Harm Category',
+ name: 'harmCategory',
+ type: 'multiOptions',
+ description:
+ 'Refer to official guide on how to use Harm Category',
+ options: [
+ {
+ label: 'Dangerous',
+ name: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+ },
+ {
+ label: 'Harassment',
+ name: HarmCategory.HARM_CATEGORY_HARASSMENT
+ },
+ {
+ label: 'Hate Speech',
+ name: HarmCategory.HARM_CATEGORY_HATE_SPEECH
+ },
+ {
+ label: 'Sexually Explicit',
+ name: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
+ }
+ ],
+ optional: true,
+ additionalParams: true
+ },
+ {
+ label: 'Harm Block Threshold',
+ name: 'harmBlockThreshold',
+ type: 'multiOptions',
+ description:
+ 'Refer to official guide on how to use Harm Block Threshold',
+ options: [
+ {
+ label: 'Low and Above',
+ name: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
+ },
+ {
+ label: 'Medium and Above',
+ name: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
+ },
+ {
+ label: 'None',
+ name: HarmBlockThreshold.BLOCK_NONE
+ },
+ {
+ label: 'Only High',
+ name: HarmBlockThreshold.BLOCK_ONLY_HIGH
+ },
+ {
+ label: 'Threshold Unspecified',
+ name: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED
+ }
+ ],
+ optional: true,
+ additionalParams: true
}
]
}
@@ -86,9 +155,12 @@ class GoogleGenerativeAI_ChatModels implements INode {
const modelName = nodeData.inputs?.modelName as string
const maxOutputTokens = nodeData.inputs?.maxOutputTokens as string
const topP = nodeData.inputs?.topP as string
+ const topK = nodeData.inputs?.topK as string
+ const harmCategory = nodeData.inputs?.harmCategory as string
+ const harmBlockThreshold = nodeData.inputs?.harmBlockThreshold as string
const cache = nodeData.inputs?.cache as BaseCache
- const obj = {
+ const obj: Partial = {
apiKey: apiKey,
modelName: modelName,
maxOutputTokens: 2048
@@ -98,8 +170,23 @@ class GoogleGenerativeAI_ChatModels implements INode {
const model = new ChatGoogleGenerativeAI(obj)
if (topP) model.topP = parseFloat(topP)
+ if (topK) model.topK = parseFloat(topK)
if (cache) model.cache = cache
if (temperature) model.temperature = parseFloat(temperature)
+
+ // Safety Settings
+ let harmCategories: string[] = convertMultiOptionsToStringArray(harmCategory)
+ let harmBlockThresholds: string[] = convertMultiOptionsToStringArray(harmBlockThreshold)
+ if (harmCategories.length != harmBlockThresholds.length)
+ throw new Error(`Harm Category & Harm Block Threshold are not the same length`)
+ const safetySettings: SafetySetting[] = harmCategories.map((harmCategory, index) => {
+ return {
+ category: harmCategory as HarmCategory,
+ threshold: harmBlockThresholds[index] as HarmBlockThreshold
+ }
+ })
+ if (safetySettings.length > 0) model.safetySettings = safetySettings
+
return model
}
}
diff --git a/packages/components/package.json b/packages/components/package.json
index ddf093998..894014d42 100644
--- a/packages/components/package.json
+++ b/packages/components/package.json
@@ -26,6 +26,7 @@
"@gomomento/sdk": "^1.51.1",
"@gomomento/sdk-core": "^1.51.1",
"@google-ai/generativelanguage": "^0.2.1",
+ "@google/generative-ai": "^0.1.3",
"@huggingface/inference": "^2.6.1",
"@langchain/community": "^0.0.16",
"@langchain/google-genai": "^0.0.6",
diff --git a/packages/components/src/utils.ts b/packages/components/src/utils.ts
index 22fa6f4a9..2215eb418 100644
--- a/packages/components/src/utils.ts
+++ b/packages/components/src/utils.ts
@@ -673,3 +673,18 @@ export const convertBaseMessagetoIMessage = (messages: BaseMessage[]): IMessage[
}
return formatmessages
}
+
+/**
+ * Convert MultiOptions String to String Array
+ * @param {string} inputString
+ * @returns {string[]}
+ */
+export const convertMultiOptionsToStringArray = (inputString: string): string[] => {
+ let ArrayString: string[] = []
+ try {
+ ArrayString = JSON.parse(inputString)
+ } catch (e) {
+ ArrayString = []
+ }
+ return ArrayString
+}