add TopK and safetySettings

This commit is contained in:
chungyau97 2024-01-13 16:18:35 +08:00
parent 1d8d97c7a1
commit c15250f28b
2 changed files with 103 additions and 1 deletions

View File

@ -1,7 +1,8 @@
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
import { convertStringToArrayString, getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
import { BaseCache } from 'langchain/schema'
import { ChatGoogleGenerativeAI } from '@langchain/google-genai'
import { HarmBlockThreshold, HarmCategory } from '@google/generative-ai'
class GoogleGenerativeAI_ChatModels implements INode {
label: string
@ -74,6 +75,72 @@ class GoogleGenerativeAI_ChatModels implements INode {
step: 0.1,
optional: true,
additionalParams: true
},
{
label: 'topK',
name: 'topK',
type: 'number',
step: 0.1,
optional: true,
additionalParams: true
},
{
label: 'Harm Category',
name: 'harmCategory',
type: 'multiOptions',
description:
'Refer to <a target="_blank" href="https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/configure-safety-attributes#gemini-TASK-samples-go">official guide</a> on how to use Harm Category',
options: [
{
label: 'Dangerous',
name: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
},
{
label: 'Harassment',
name: HarmCategory.HARM_CATEGORY_HARASSMENT
},
{
label: 'Hate Speech',
name: HarmCategory.HARM_CATEGORY_HATE_SPEECH
},
{
label: 'Sexually Explicit',
name: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
}
],
optional: true,
additionalParams: true
},
{
label: 'Harm Block Threshold',
name: 'harmBlockThreshold',
type: 'multiOptions',
description:
'Refer to <a target="_blank" href="https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/configure-safety-attributes#gemini-TASK-samples-go">official guide</a> on how to use Harm Block Threshold',
options: [
{
label: 'Low and Above',
name: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
},
{
label: 'Medium and Above',
name: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
},
{
label: 'None',
name: HarmBlockThreshold.BLOCK_NONE
},
{
label: 'Only High',
name: HarmBlockThreshold.BLOCK_ONLY_HIGH
},
{
label: 'Threshold Unspecified',
name: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED
}
],
optional: true,
additionalParams: true
}
]
}
@ -86,6 +153,9 @@ class GoogleGenerativeAI_ChatModels implements INode {
const modelName = nodeData.inputs?.modelName as string
const maxOutputTokens = nodeData.inputs?.maxOutputTokens as string
const topP = nodeData.inputs?.topP as string
const topK = nodeData.inputs?.topK as string
const harmCategory = nodeData.inputs?.harmCategory as string
const harmBlockThreshold = nodeData.inputs?.harmBlockThreshold as string
const cache = nodeData.inputs?.cache as BaseCache
const obj = {
@ -98,8 +168,23 @@ class GoogleGenerativeAI_ChatModels implements INode {
const model = new ChatGoogleGenerativeAI(obj)
if (topP) model.topP = parseFloat(topP)
if (topK) model.topP = parseFloat(topK)
if (cache) model.cache = cache
if (temperature) model.temperature = parseFloat(temperature)
// safetySettings
let harmCategories: string[] = convertStringToArrayString(harmCategory)
let harmBlockThresholds: string[] = convertStringToArrayString(harmBlockThreshold)
if (harmCategories.length != harmBlockThresholds.length)
throw new Error(`Harm Category & Harm Block Threshold are not the same length`)
const safetySettings = harmCategories.map((value, index) => {
return {
category: value,
threshold: harmBlockThresholds[index]
}
})
if (safetySettings) model.safetySettings = safetySettings
return model
}
}

View File

@ -673,3 +673,20 @@ export const convertBaseMessagetoIMessage = (messages: BaseMessage[]): IMessage[
}
return formatmessages
}
/**
* Convert String to Array String
* @param {string} inputString
* @returns {string[]}
*/
export const convertStringToArrayString = (inputString: string): string[] => {
let ArrayString: string[] = []
if (inputString) {
try {
ArrayString = JSON.parse(inputString)
} catch (e) {
ArrayString = []
}
}
return ArrayString
}