add TopK and safetySettings
This commit is contained in:
parent
1d8d97c7a1
commit
c15250f28b
|
|
@ -1,7 +1,8 @@
|
||||||
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
|
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
|
||||||
import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
|
import { convertStringToArrayString, getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
|
||||||
import { BaseCache } from 'langchain/schema'
|
import { BaseCache } from 'langchain/schema'
|
||||||
import { ChatGoogleGenerativeAI } from '@langchain/google-genai'
|
import { ChatGoogleGenerativeAI } from '@langchain/google-genai'
|
||||||
|
import { HarmBlockThreshold, HarmCategory } from '@google/generative-ai'
|
||||||
|
|
||||||
class GoogleGenerativeAI_ChatModels implements INode {
|
class GoogleGenerativeAI_ChatModels implements INode {
|
||||||
label: string
|
label: string
|
||||||
|
|
@ -74,6 +75,72 @@ class GoogleGenerativeAI_ChatModels implements INode {
|
||||||
step: 0.1,
|
step: 0.1,
|
||||||
optional: true,
|
optional: true,
|
||||||
additionalParams: true
|
additionalParams: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'topK',
|
||||||
|
name: 'topK',
|
||||||
|
type: 'number',
|
||||||
|
step: 0.1,
|
||||||
|
optional: true,
|
||||||
|
additionalParams: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Harm Category',
|
||||||
|
name: 'harmCategory',
|
||||||
|
type: 'multiOptions',
|
||||||
|
description:
|
||||||
|
'Refer to <a target="_blank" href="https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/configure-safety-attributes#gemini-TASK-samples-go">official guide</a> on how to use Harm Category',
|
||||||
|
options: [
|
||||||
|
{
|
||||||
|
label: 'Dangerous',
|
||||||
|
name: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Harassment',
|
||||||
|
name: HarmCategory.HARM_CATEGORY_HARASSMENT
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Hate Speech',
|
||||||
|
name: HarmCategory.HARM_CATEGORY_HATE_SPEECH
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Sexually Explicit',
|
||||||
|
name: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
|
||||||
|
}
|
||||||
|
],
|
||||||
|
optional: true,
|
||||||
|
additionalParams: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Harm Block Threshold',
|
||||||
|
name: 'harmBlockThreshold',
|
||||||
|
type: 'multiOptions',
|
||||||
|
description:
|
||||||
|
'Refer to <a target="_blank" href="https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/configure-safety-attributes#gemini-TASK-samples-go">official guide</a> on how to use Harm Block Threshold',
|
||||||
|
options: [
|
||||||
|
{
|
||||||
|
label: 'Low and Above',
|
||||||
|
name: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Medium and Above',
|
||||||
|
name: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'None',
|
||||||
|
name: HarmBlockThreshold.BLOCK_NONE
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Only High',
|
||||||
|
name: HarmBlockThreshold.BLOCK_ONLY_HIGH
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Threshold Unspecified',
|
||||||
|
name: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED
|
||||||
|
}
|
||||||
|
],
|
||||||
|
optional: true,
|
||||||
|
additionalParams: true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
@ -86,6 +153,9 @@ class GoogleGenerativeAI_ChatModels implements INode {
|
||||||
const modelName = nodeData.inputs?.modelName as string
|
const modelName = nodeData.inputs?.modelName as string
|
||||||
const maxOutputTokens = nodeData.inputs?.maxOutputTokens as string
|
const maxOutputTokens = nodeData.inputs?.maxOutputTokens as string
|
||||||
const topP = nodeData.inputs?.topP as string
|
const topP = nodeData.inputs?.topP as string
|
||||||
|
const topK = nodeData.inputs?.topK as string
|
||||||
|
const harmCategory = nodeData.inputs?.harmCategory as string
|
||||||
|
const harmBlockThreshold = nodeData.inputs?.harmBlockThreshold as string
|
||||||
const cache = nodeData.inputs?.cache as BaseCache
|
const cache = nodeData.inputs?.cache as BaseCache
|
||||||
|
|
||||||
const obj = {
|
const obj = {
|
||||||
|
|
@ -98,8 +168,23 @@ class GoogleGenerativeAI_ChatModels implements INode {
|
||||||
|
|
||||||
const model = new ChatGoogleGenerativeAI(obj)
|
const model = new ChatGoogleGenerativeAI(obj)
|
||||||
if (topP) model.topP = parseFloat(topP)
|
if (topP) model.topP = parseFloat(topP)
|
||||||
|
if (topK) model.topP = parseFloat(topK)
|
||||||
if (cache) model.cache = cache
|
if (cache) model.cache = cache
|
||||||
if (temperature) model.temperature = parseFloat(temperature)
|
if (temperature) model.temperature = parseFloat(temperature)
|
||||||
|
|
||||||
|
// safetySettings
|
||||||
|
let harmCategories: string[] = convertStringToArrayString(harmCategory)
|
||||||
|
let harmBlockThresholds: string[] = convertStringToArrayString(harmBlockThreshold)
|
||||||
|
if (harmCategories.length != harmBlockThresholds.length)
|
||||||
|
throw new Error(`Harm Category & Harm Block Threshold are not the same length`)
|
||||||
|
const safetySettings = harmCategories.map((value, index) => {
|
||||||
|
return {
|
||||||
|
category: value,
|
||||||
|
threshold: harmBlockThresholds[index]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if (safetySettings) model.safetySettings = safetySettings
|
||||||
|
|
||||||
return model
|
return model
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -673,3 +673,20 @@ export const convertBaseMessagetoIMessage = (messages: BaseMessage[]): IMessage[
|
||||||
}
|
}
|
||||||
return formatmessages
|
return formatmessages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert String to Array String
|
||||||
|
* @param {string} inputString
|
||||||
|
* @returns {string[]}
|
||||||
|
*/
|
||||||
|
export const convertStringToArrayString = (inputString: string): string[] => {
|
||||||
|
let ArrayString: string[] = []
|
||||||
|
if (inputString) {
|
||||||
|
try {
|
||||||
|
ArrayString = JSON.parse(inputString)
|
||||||
|
} catch (e) {
|
||||||
|
ArrayString = []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ArrayString
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue