diff --git a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts index bd660b474..9a4b8891b 100644 --- a/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts +++ b/packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts @@ -78,10 +78,11 @@ class GoogleGenerativeAI_ChatModels implements INode { additionalParams: true }, { - label: 'topK', + label: 'Top Next Highest Probability Tokens', name: 'topK', type: 'number', - step: 0.1, + description: `Decode using top-k sampling: consider the set of top_k most probable tokens. Must be positive`, + step: 1, optional: true, additionalParams: true }, @@ -90,7 +91,7 @@ class GoogleGenerativeAI_ChatModels implements INode { name: 'harmCategory', type: 'multiOptions', description: - 'Refer to official guide on how to use Harm Category', + 'Refer to official guide on how to use Harm Category', options: [ { label: 'Dangerous', @@ -117,7 +118,7 @@ class GoogleGenerativeAI_ChatModels implements INode { name: 'harmBlockThreshold', type: 'multiOptions', description: - 'Refer to official guide on how to use Harm Block Threshold', + 'Refer to official guide on how to use Harm Block Threshold', options: [ { label: 'Low and Above', @@ -169,7 +170,7 @@ class GoogleGenerativeAI_ChatModels implements INode { const model = new ChatGoogleGenerativeAI(obj) if (topP) model.topP = parseFloat(topP) - if (topK) model.topP = parseFloat(topK) + if (topK) model.topK = parseFloat(topK) if (cache) model.cache = cache if (temperature) model.temperature = parseFloat(temperature) @@ -178,10 +179,10 @@ class GoogleGenerativeAI_ChatModels implements INode { let harmBlockThresholds: string[] = convertMultiOptionsToStringArray(harmBlockThreshold) if (harmCategories.length != harmBlockThresholds.length) throw new Error(`Harm Category & Harm Block Threshold are not the same length`) - const safetySettings: SafetySetting[] = harmCategories.map((value, index) => { + const safetySettings: SafetySetting[] = harmCategories.map((harmCategory, index) => { return { - category: categoryInput(value), - threshold: thresholdInput(harmBlockThresholds[index]) + category: harmCategory as HarmCategory, + threshold: harmBlockThresholds[index] as HarmBlockThreshold } }) if (safetySettings.length > 0) model.safetySettings = safetySettings @@ -190,46 +191,4 @@ class GoogleGenerativeAI_ChatModels implements INode { } } -const categoryInput = (categoryInput: string): HarmCategory => { - let categoryOutput: HarmCategory - switch (categoryInput) { - case HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: - categoryOutput = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT - break - case HarmCategory.HARM_CATEGORY_HATE_SPEECH: - categoryOutput = HarmCategory.HARM_CATEGORY_HATE_SPEECH - break - case HarmCategory.HARM_CATEGORY_HARASSMENT: - categoryOutput = HarmCategory.HARM_CATEGORY_HARASSMENT - break - case HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: - categoryOutput = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT - break - default: - categoryOutput = HarmCategory.HARM_CATEGORY_UNSPECIFIED - } - return categoryOutput -} - -const thresholdInput = (thresholdInput: string): HarmBlockThreshold => { - let thresholdOutput: HarmBlockThreshold - switch (thresholdInput) { - case HarmBlockThreshold.BLOCK_LOW_AND_ABOVE: - thresholdOutput = HarmBlockThreshold.BLOCK_LOW_AND_ABOVE - break - case HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: - thresholdOutput = HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE - break - case HarmBlockThreshold.BLOCK_NONE: - thresholdOutput = HarmBlockThreshold.BLOCK_NONE - break - case HarmBlockThreshold.BLOCK_ONLY_HIGH: - thresholdOutput = HarmBlockThreshold.BLOCK_ONLY_HIGH - break - default: - thresholdOutput = HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED - } - return thresholdOutput -} - module.exports = { nodeClass: GoogleGenerativeAI_ChatModels } diff --git a/packages/components/src/utils.ts b/packages/components/src/utils.ts index 2d983562e..eacfa4a0b 100644 --- a/packages/components/src/utils.ts +++ b/packages/components/src/utils.ts @@ -681,12 +681,7 @@ export const convertBaseMessagetoIMessage = (messages: BaseMessage[]): IMessage[ */ export const convertMultiOptionsToStringArray = (inputString: string): string[] => { let ArrayString: string[] = [] - if (inputString) { - try { - ArrayString = JSON.parse(inputString) - } catch (e) { - ArrayString = [] - } - } + if (inputString) ArrayString = JSON.parse(inputString) + return ArrayString }