diff --git a/packages/server/src/Interface.ts b/packages/server/src/Interface.ts index 08906d44c..d228d9376 100644 --- a/packages/server/src/Interface.ts +++ b/packages/server/src/Interface.ts @@ -95,6 +95,10 @@ export interface INodeQueue { depth: number } +export interface IDepthQueue { + [key: string]: number +} + export interface IMessage { message: string type: MessageType diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 777d5fd17..cec71051b 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -5,7 +5,7 @@ import http from 'http' import * as fs from 'fs' import { IChatFlow, IncomingInput, IReactFlowNode, IReactFlowObject } from './Interface' -import { getNodeModulesPackagePath, getStartingNode, buildLangchain, getEndingNode, constructGraphs } from './utils' +import { getNodeModulesPackagePath, getStartingNodes, buildLangchain, getEndingNode, constructGraphs } from './utils' import { cloneDeep } from 'lodash' import { getDataSource } from './DataSource' import { NodesPool } from './NodesPool' @@ -209,13 +209,26 @@ export class App { const flowData = chatflow.flowData const parsedFlowData: IReactFlowObject = JSON.parse(flowData) + + /*** Get Ending Node with Directed Graph ***/ const { graph, nodeDependencies } = constructGraphs(parsedFlowData.nodes, parsedFlowData.edges) + const directedGraph = graph + const endingNodeId = getEndingNode(nodeDependencies, directedGraph) + if (!endingNodeId) return res.status(500).send(`Ending node must be either a Chain or Agent`) - const startingNodeIds = getStartingNode(nodeDependencies) - const endingNodeId = getEndingNode(nodeDependencies, graph) - if (!endingNodeId) return res.status(500).send(`Ending node must be either Chain or Agent`) + /*** Get Starting Nodes with Non-Directed Graph ***/ + const constructedObj = constructGraphs(parsedFlowData.nodes, parsedFlowData.edges, true) + const nonDirectedGraph = constructedObj.graph + const { startingNodeIds, depthQueue } = getStartingNodes(nonDirectedGraph, endingNodeId) - const reactFlowNodes = await buildLangchain(startingNodeIds, parsedFlowData.nodes, graph, this.nodesPool.componentNodes) + /*** BFS to traverse from Starting Nodes to Ending Node ***/ + const reactFlowNodes = await buildLangchain( + startingNodeIds, + parsedFlowData.nodes, + graph, + depthQueue, + this.nodesPool.componentNodes + ) const nodeToExecute = reactFlowNodes.find((node: IReactFlowNode) => node.id === endingNodeId) if (!nodeToExecute) return res.status(404).send(`Node ${endingNodeId} not found`) diff --git a/packages/server/src/utils/index.ts b/packages/server/src/utils/index.ts index e8ac8a3aa..fce1e02e6 100644 --- a/packages/server/src/utils/index.ts +++ b/packages/server/src/utils/index.ts @@ -2,6 +2,7 @@ import path from 'path' import fs from 'fs' import { IComponentNodes, + IDepthQueue, IExploredNode, INodeDependencies, INodeDirectedGraph, @@ -54,11 +55,12 @@ export const getNodeModulesPackagePath = (packageName: string): string => { } /** - * Construct directed graph and node dependencies score + * Construct graph and node dependencies score * @param {IReactFlowNode[]} reactFlowNodes * @param {IReactFlowEdge[]} reactFlowEdges + * @param {boolean} isNondirected */ -export const constructGraphs = (reactFlowNodes: IReactFlowNode[], reactFlowEdges: IReactFlowEdge[]) => { +export const constructGraphs = (reactFlowNodes: IReactFlowNode[], reactFlowEdges: IReactFlowEdge[], isNondirected = false) => { const nodeDependencies = {} as INodeDependencies const graph = {} as INodeDirectedGraph @@ -77,6 +79,14 @@ export const constructGraphs = (reactFlowNodes: IReactFlowNode[], reactFlowEdges } else { graph[source] = [target] } + + if (isNondirected) { + if (Object.prototype.hasOwnProperty.call(graph, target)) { + graph[target].push(source) + } else { + graph[target] = [source] + } + } nodeDependencies[target] += 1 } @@ -84,18 +94,52 @@ export const constructGraphs = (reactFlowNodes: IReactFlowNode[], reactFlowEdges } /** - * Get starting node and check if flow is valid - * @param {INodeDependencies} nodeDependencies + * Get starting nodes and check if flow is valid + * @param {INodeDependencies} graph + * @param {string} endNodeId */ -export const getStartingNode = (nodeDependencies: INodeDependencies) => { - // Find starting node - const startingNodeIds = [] as string[] - Object.keys(nodeDependencies).forEach((nodeId) => { - if (nodeDependencies[nodeId] === 0) { - startingNodeIds.push(nodeId) +export const getStartingNodes = (graph: INodeDirectedGraph, endNodeId: string) => { + const visited = new Set() + const queue: Array<[string, number]> = [[endNodeId, 0]] + const depthQueue: IDepthQueue = { + [endNodeId]: 0 + } + + let maxDepth = 0 + let startingNodeIds: string[] = [] + + while (queue.length > 0) { + const [currentNode, depth] = queue.shift()! + + if (visited.has(currentNode)) { + continue } - }) - return startingNodeIds + + visited.add(currentNode) + + if (depth > maxDepth) { + maxDepth = depth + startingNodeIds = [currentNode] + } else if (depth === maxDepth) { + startingNodeIds.push(currentNode) + } + + for (const neighbor of graph[currentNode]) { + if (!visited.has(neighbor)) { + queue.push([neighbor, depth + 1]) + depthQueue[neighbor] = depth + 1 + } + } + } + + const depthQueueReversed: IDepthQueue = {} + for (const nodeId in depthQueue) { + if (Object.prototype.hasOwnProperty.call(depthQueue, nodeId)) { + depthQueueReversed[nodeId] = Math.abs(depthQueue[nodeId] - maxDepth) + } + } + + return { startingNodeIds, depthQueue: depthQueueReversed } } /** @@ -104,7 +148,6 @@ export const getStartingNode = (nodeDependencies: INodeDependencies) => { * @param {INodeDirectedGraph} graph */ export const getEndingNode = (nodeDependencies: INodeDependencies, graph: INodeDirectedGraph) => { - // Find ending node let endingNodeId = '' Object.keys(graph).forEach((nodeId) => { if (Object.keys(nodeDependencies).length === 1) { @@ -121,12 +164,14 @@ export const getEndingNode = (nodeDependencies: INodeDependencies, graph: INodeD * @param {string} startingNodeId * @param {IReactFlowNode[]} reactFlowNodes * @param {INodeDirectedGraph} graph + * @param {IDepthQueue} depthQueue * @param {IComponentNodes} componentNodes */ export const buildLangchain = async ( startingNodeIds: string[], reactFlowNodes: IReactFlowNode[], graph: INodeDirectedGraph, + depthQueue: IDepthQueue, componentNodes: IComponentNodes ) => { const flowNodes = cloneDeep(reactFlowNodes) @@ -166,6 +211,14 @@ export const buildLangchain = async ( const neighbourNodeIds = graph[nodeId] const nextDepth = depth + 1 + // Find other nodes that are on the same depth level + const sameDepthNodeIds = Object.keys(depthQueue).filter((key) => depthQueue[key] === nextDepth) + + for (const id of sameDepthNodeIds) { + if (neighbourNodeIds.includes(id)) continue + neighbourNodeIds.push(id) + } + for (let i = 0; i < neighbourNodeIds.length; i += 1) { const neighNodeId = neighbourNodeIds[i]