From e39fd130d0836a1b1dc6c4027552deddd854c7fe Mon Sep 17 00:00:00 2001 From: Henry Heng Date: Sat, 27 Jul 2024 13:13:16 +0100 Subject: [PATCH] Feat/add ability to specify initial state in overrideConfig (#2893) add ability to specify initial state in overrideConfig --- .../nodes/sequentialagents/Agent/Agent.ts | 22 +++++++++++ .../nodes/sequentialagents/LLMNode/LLMNode.ts | 22 +++++++++++ .../nodes/sequentialagents/State/State.ts | 37 +++++++++++++++++++ .../sequentialagents/ToolNode/ToolNode.ts | 22 +++++++++++ 4 files changed, 103 insertions(+) diff --git a/packages/components/nodes/sequentialagents/Agent/Agent.ts b/packages/components/nodes/sequentialagents/Agent/Agent.ts index 79ce87065..eea8b1401 100644 --- a/packages/components/nodes/sequentialagents/Agent/Agent.ts +++ b/packages/components/nodes/sequentialagents/Agent/Agent.ts @@ -754,6 +754,7 @@ const getReturnOutput = async (nodeData: INodeData, input: string, options: ICom const tabIdentifier = nodeData.inputs?.[`${TAB_IDENTIFIER}_${nodeData.id}`] as string const updateStateMemoryUI = nodeData.inputs?.updateStateMemoryUI as string const updateStateMemoryCode = nodeData.inputs?.updateStateMemoryCode as string + const updateStateMemory = nodeData.inputs?.updateStateMemory as string const selectedTab = tabIdentifier ? tabIdentifier.split(`_${nodeData.id}`)[0] : 'updateStateMemoryUI' const variables = await getVars(appDataSource, databaseEntities, nodeData) @@ -768,6 +769,27 @@ const getReturnOutput = async (nodeData: INodeData, input: string, options: ICom vars: prepareSandboxVars(variables) } + if (updateStateMemory && updateStateMemory !== 'updateStateMemoryUI' && updateStateMemory !== 'updateStateMemoryCode') { + try { + const parsedSchema = typeof updateStateMemory === 'string' ? JSON.parse(updateStateMemory) : updateStateMemory + const obj: ICommonObject = {} + for (const sch of parsedSchema) { + const key = sch.Key + if (!key) throw new Error(`Key is required`) + let value = sch.Value as string + if (value.startsWith('$flow')) { + value = customGet(flow, sch.Value.replace('$flow.', '')) + } else if (value.startsWith('$vars')) { + value = customGet(flow, sch.Value.replace('$', '')) + } + obj[key] = value + } + return obj + } catch (e) { + throw new Error(e) + } + } + if (selectedTab === 'updateStateMemoryUI' && updateStateMemoryUI) { try { const parsedSchema = typeof updateStateMemoryUI === 'string' ? JSON.parse(updateStateMemoryUI) : updateStateMemoryUI diff --git a/packages/components/nodes/sequentialagents/LLMNode/LLMNode.ts b/packages/components/nodes/sequentialagents/LLMNode/LLMNode.ts index a09bf45cc..143e4cd42 100644 --- a/packages/components/nodes/sequentialagents/LLMNode/LLMNode.ts +++ b/packages/components/nodes/sequentialagents/LLMNode/LLMNode.ts @@ -557,6 +557,7 @@ const getReturnOutput = async (nodeData: INodeData, input: string, options: ICom const tabIdentifier = nodeData.inputs?.[`${TAB_IDENTIFIER}_${nodeData.id}`] as string const updateStateMemoryUI = nodeData.inputs?.updateStateMemoryUI as string const updateStateMemoryCode = nodeData.inputs?.updateStateMemoryCode as string + const updateStateMemory = nodeData.inputs?.updateStateMemory as string const selectedTab = tabIdentifier ? tabIdentifier.split(`_${nodeData.id}`)[0] : 'updateStateMemoryUI' const variables = await getVars(appDataSource, databaseEntities, nodeData) @@ -571,6 +572,27 @@ const getReturnOutput = async (nodeData: INodeData, input: string, options: ICom vars: prepareSandboxVars(variables) } + if (updateStateMemory && updateStateMemory !== 'updateStateMemoryUI' && updateStateMemory !== 'updateStateMemoryCode') { + try { + const parsedSchema = typeof updateStateMemory === 'string' ? JSON.parse(updateStateMemory) : updateStateMemory + const obj: ICommonObject = {} + for (const sch of parsedSchema) { + const key = sch.Key + if (!key) throw new Error(`Key is required`) + let value = sch.Value as string + if (value.startsWith('$flow')) { + value = customGet(flow, sch.Value.replace('$flow.', '')) + } else if (value.startsWith('$vars')) { + value = customGet(flow, sch.Value.replace('$', '')) + } + obj[key] = value + } + return obj + } catch (e) { + throw new Error(e) + } + } + if (selectedTab === 'updateStateMemoryUI' && updateStateMemoryUI) { try { const parsedSchema = typeof updateStateMemoryUI === 'string' ? JSON.parse(updateStateMemoryUI) : updateStateMemoryUI diff --git a/packages/components/nodes/sequentialagents/State/State.ts b/packages/components/nodes/sequentialagents/State/State.ts index 71d9ea3d4..06331186e 100644 --- a/packages/components/nodes/sequentialagents/State/State.ts +++ b/packages/components/nodes/sequentialagents/State/State.ts @@ -101,6 +101,43 @@ class State_SeqAgents implements INode { const appDataSource = options.appDataSource as DataSource const databaseEntities = options.databaseEntities as IDatabaseEntity const selectedTab = tabIdentifier ? tabIdentifier.split(`_${nodeData.id}`)[0] : 'stateMemoryUI' + const stateMemory = nodeData.inputs?.stateMemory as string + + if (stateMemory && stateMemory !== 'stateMemoryUI' && stateMemory !== 'stateMemoryCode') { + try { + const parsedSchema = typeof stateMemory === 'string' ? JSON.parse(stateMemory) : stateMemory + const obj: ICommonObject = {} + for (const sch of parsedSchema) { + const key = sch.Key + if (!key) throw new Error(`Key is required`) + const type = sch.Operation + const defaultValue = sch['Default Value'] + + if (type === 'Append') { + obj[key] = { + value: (x: any, y: any) => (Array.isArray(y) ? x.concat(y) : x.concat([y])), + default: () => (defaultValue ? JSON.parse(defaultValue) : []) + } + } else { + obj[key] = { + value: (x: any, y: any) => y ?? x, + default: () => defaultValue + } + } + } + const returnOutput: ISeqAgentNode = { + id: nodeData.id, + node: obj, + name: 'state', + label: 'state', + type: 'state', + output: START + } + return returnOutput + } catch (e) { + throw new Error(e) + } + } if (!stateMemoryUI && !stateMemoryCode) { const returnOutput: ISeqAgentNode = { diff --git a/packages/components/nodes/sequentialagents/ToolNode/ToolNode.ts b/packages/components/nodes/sequentialagents/ToolNode/ToolNode.ts index 607b152b4..c498d2fb6 100644 --- a/packages/components/nodes/sequentialagents/ToolNode/ToolNode.ts +++ b/packages/components/nodes/sequentialagents/ToolNode/ToolNode.ts @@ -441,6 +441,7 @@ const getReturnOutput = async ( const tabIdentifier = nodeData.inputs?.[`${TAB_IDENTIFIER}_${nodeData.id}`] as string const updateStateMemoryUI = nodeData.inputs?.updateStateMemoryUI as string const updateStateMemoryCode = nodeData.inputs?.updateStateMemoryCode as string + const updateStateMemory = nodeData.inputs?.updateStateMemory as string const selectedTab = tabIdentifier ? tabIdentifier.split(`_${nodeData.id}`)[0] : 'updateStateMemoryUI' const variables = await getVars(appDataSource, databaseEntities, nodeData) @@ -464,6 +465,27 @@ const getReturnOutput = async ( vars: prepareSandboxVars(variables) } + if (updateStateMemory && updateStateMemory !== 'updateStateMemoryUI' && updateStateMemory !== 'updateStateMemoryCode') { + try { + const parsedSchema = typeof updateStateMemory === 'string' ? JSON.parse(updateStateMemory) : updateStateMemory + const obj: ICommonObject = {} + for (const sch of parsedSchema) { + const key = sch.Key + if (!key) throw new Error(`Key is required`) + let value = sch.Value as string + if (value.startsWith('$flow')) { + value = customGet(flow, sch.Value.replace('$flow.', '')) + } else if (value.startsWith('$vars')) { + value = customGet(flow, sch.Value.replace('$', '')) + } + obj[key] = value + } + return obj + } catch (e) { + throw new Error(e) + } + } + if (selectedTab === 'updateStateMemoryUI' && updateStateMemoryUI) { try { const parsedSchema = typeof updateStateMemoryUI === 'string' ? JSON.parse(updateStateMemoryUI) : updateStateMemoryUI