diff --git a/packages/components/nodes/chains/ConversationChain/ConversationChain.ts b/packages/components/nodes/chains/ConversationChain/ConversationChain.ts index 2233ebf94..4e39ae6db 100644 --- a/packages/components/nodes/chains/ConversationChain/ConversationChain.ts +++ b/packages/components/nodes/chains/ConversationChain/ConversationChain.ts @@ -1,6 +1,6 @@ import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface' import { ConversationChain } from 'langchain/chains' -import { getBaseClasses } from '../../../src/utils' +import { CustomChainHandler, getBaseClasses } from '../../../src/utils' import { ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate } from 'langchain/prompts' import { BufferMemory, ChatMessageHistory } from 'langchain/memory' import { BaseChatModel } from 'langchain/chat_models/base' @@ -90,8 +90,14 @@ class ConversationChain_Chains implements INode { chain.memory = memory } - const res = await chain.call({ input }) - return res?.response + if (options.socketIO && options.socketIOClientId) { + const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + const res = await chain.call({ input }, [handler]) + return res?.response + } else { + const res = await chain.call({ input }) + return res?.text + } } } diff --git a/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts b/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts index 7e37f9131..659e1e2e6 100644 --- a/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts +++ b/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts @@ -1,6 +1,6 @@ import { BaseLanguageModel } from 'langchain/base_language' import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface' -import { getBaseClasses } from '../../../src/utils' +import { CustomChainHandler, getBaseClasses } from '../../../src/utils' import { ConversationalRetrievalQAChain } from 'langchain/chains' import { BaseRetriever } from 'langchain/schema' @@ -74,6 +74,12 @@ class ConversationalRetrievalQAChain_Chains implements INode { async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const chain = nodeData.instance as ConversationalRetrievalQAChain + let model = nodeData.inputs?.model + + // Temporary fix: https://github.com/hwchase17/langchainjs/issues/754 + model.streaming = false + chain.questionGeneratorChain.llm = model + let chatHistory = '' if (options && options.chatHistory) { @@ -90,9 +96,14 @@ class ConversationalRetrievalQAChain_Chains implements INode { chat_history: chatHistory ? chatHistory : [] } - const res = await chain.call(obj) - - return res?.text + if (options.socketIO && options.socketIOClientId) { + const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + const res = await chain.call(obj, [handler]) + return res?.text + } else { + const res = await chain.call(obj) + return res?.text + } } } diff --git a/packages/components/nodes/chains/LLMChain/LLMChain.ts b/packages/components/nodes/chains/LLMChain/LLMChain.ts index b178e28df..9cd08d353 100644 --- a/packages/components/nodes/chains/LLMChain/LLMChain.ts +++ b/packages/components/nodes/chains/LLMChain/LLMChain.ts @@ -1,5 +1,5 @@ import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/Interface' -import { getBaseClasses } from '../../../src/utils' +import { CustomChainHandler, getBaseClasses } from '../../../src/utils' import { LLMChain } from 'langchain/chains' import { BaseLanguageModel } from 'langchain/base_language' @@ -76,12 +76,14 @@ class LLMChain_Chains implements INode { } } - async run(nodeData: INodeData, input: string): Promise { + async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const inputVariables = nodeData.instance.prompt.inputVariables as string[] // ["product"] const chain = nodeData.instance as LLMChain const promptValues = nodeData.inputs?.prompt.promptValues as ICommonObject - const res = await runPrediction(inputVariables, chain, input, promptValues) + const res = options.socketIO + ? await runPrediction(inputVariables, chain, input, promptValues, true, options.socketIO, options.socketIOClientId) + : await runPrediction(inputVariables, chain, input, promptValues) // eslint-disable-next-line no-console console.log('\x1b[93m\x1b[1m\n*****FINAL RESULT*****\n\x1b[0m\x1b[0m') // eslint-disable-next-line no-console @@ -90,10 +92,24 @@ class LLMChain_Chains implements INode { } } -const runPrediction = async (inputVariables: string[], chain: LLMChain, input: string, promptValues: ICommonObject) => { +const runPrediction = async ( + inputVariables: string[], + chain: LLMChain, + input: string, + promptValues: ICommonObject, + isStreaming?: boolean, + socketIO?: any, + socketIOClientId = '' +) => { if (inputVariables.length === 1) { - const res = await chain.run(input) - return res + if (isStreaming) { + const handler = new CustomChainHandler(socketIO, socketIOClientId) + const res = await chain.run(input, [handler]) + return res + } else { + const res = await chain.run(input) + return res + } } else if (inputVariables.length > 1) { let seen: string[] = [] @@ -109,8 +125,14 @@ const runPrediction = async (inputVariables: string[], chain: LLMChain, input: s const options = { ...promptValues } - const res = await chain.call(options) - return res?.text + if (isStreaming) { + const handler = new CustomChainHandler(socketIO, socketIOClientId) + const res = await chain.call(options, [handler]) + return res?.text + } else { + const res = await chain.call(options) + return res?.text + } } else if (seen.length === 1) { // If one inputVariable is not specify, use input (user's question) as value const lastValue = seen.pop() @@ -119,14 +141,26 @@ const runPrediction = async (inputVariables: string[], chain: LLMChain, input: s ...promptValues, [lastValue]: input } - const res = await chain.call(options) - return res?.text + if (isStreaming) { + const handler = new CustomChainHandler(socketIO, socketIOClientId) + const res = await chain.call(options, [handler]) + return res?.text + } else { + const res = await chain.call(options) + return res?.text + } } else { throw new Error(`Please provide Prompt Values for: ${seen.join(', ')}`) } } else { - const res = await chain.run(input) - return res + if (isStreaming) { + const handler = new CustomChainHandler(socketIO, socketIOClientId) + const res = await chain.run(input, [handler]) + return res + } else { + const res = await chain.run(input) + return res + } } } diff --git a/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts b/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts index c002b6848..97fa51a1c 100644 --- a/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts +++ b/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts @@ -1,7 +1,7 @@ -import { INode, INodeData, INodeParams } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' import { RetrievalQAChain } from 'langchain/chains' import { BaseRetriever } from 'langchain/schema' -import { getBaseClasses } from '../../../src/utils' +import { CustomChainHandler, getBaseClasses } from '../../../src/utils' import { BaseLanguageModel } from 'langchain/base_language' class RetrievalQAChain_Chains implements INode { @@ -44,13 +44,20 @@ class RetrievalQAChain_Chains implements INode { return chain } - async run(nodeData: INodeData, input: string): Promise { + async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const chain = nodeData.instance as RetrievalQAChain const obj = { query: input } - const res = await chain.call(obj) - return res?.text + + if (options.socketIO && options.socketIOClientId) { + const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + const res = await chain.call(obj, [handler]) + return res?.text + } else { + const res = await chain.call(obj) + return res?.text + } } } diff --git a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts index 7ea10d941..27a245e1a 100644 --- a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts +++ b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts @@ -1,6 +1,6 @@ -import { INode, INodeData, INodeParams } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' import { SqlDatabaseChain, SqlDatabaseChainInput } from 'langchain/chains' -import { getBaseClasses } from '../../../src/utils' +import { CustomChainHandler, getBaseClasses } from '../../../src/utils' import { DataSource } from 'typeorm' import { SqlDatabase } from 'langchain/sql_db' import { BaseLanguageModel } from 'langchain/base_language' @@ -59,14 +59,20 @@ class SqlDatabaseChain_Chains implements INode { return chain } - async run(nodeData: INodeData, input: string): Promise { + async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const databaseType = nodeData.inputs?.database as 'sqlite' const model = nodeData.inputs?.model as BaseLanguageModel const dbFilePath = nodeData.inputs?.dbFilePath const chain = await getSQLDBChain(databaseType, dbFilePath, model) - const res = await chain.run(input) - return res + if (options.socketIO && options.socketIOClientId) { + const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + const res = await chain.run(input, [handler]) + return res + } else { + const res = await chain.run(input) + return res + } } } diff --git a/packages/components/nodes/chains/VectorDBQAChain/VectorDBQAChain.ts b/packages/components/nodes/chains/VectorDBQAChain/VectorDBQAChain.ts index 37d388b4a..f752d60cc 100644 --- a/packages/components/nodes/chains/VectorDBQAChain/VectorDBQAChain.ts +++ b/packages/components/nodes/chains/VectorDBQAChain/VectorDBQAChain.ts @@ -1,5 +1,5 @@ -import { INode, INodeData, INodeParams } from '../../../src/Interface' -import { getBaseClasses } from '../../../src/utils' +import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { CustomChainHandler, getBaseClasses } from '../../../src/utils' import { VectorDBQAChain } from 'langchain/chains' import { BaseLanguageModel } from 'langchain/base_language' import { VectorStore } from 'langchain/vectorstores' @@ -44,13 +44,20 @@ class VectorDBQAChain_Chains implements INode { return chain } - async run(nodeData: INodeData, input: string): Promise { + async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const chain = nodeData.instance as VectorDBQAChain const obj = { query: input } - const res = await chain.call(obj) - return res?.text + + if (options.socketIO && options.socketIOClientId) { + const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + const res = await chain.call(obj, [handler]) + return res?.text + } else { + const res = await chain.call(obj) + return res?.text + } } } diff --git a/packages/components/nodes/chatmodels/AzureChatOpenAI/AzureChatOpenAI.ts b/packages/components/nodes/chatmodels/AzureChatOpenAI/AzureChatOpenAI.ts index 1d2fabc76..7857bfdf4 100644 --- a/packages/components/nodes/chatmodels/AzureChatOpenAI/AzureChatOpenAI.ts +++ b/packages/components/nodes/chatmodels/AzureChatOpenAI/AzureChatOpenAI.ts @@ -121,6 +121,7 @@ class AzureChatOpenAI_ChatModels implements INode { const frequencyPenalty = nodeData.inputs?.frequencyPenalty as string const presencePenalty = nodeData.inputs?.presencePenalty as string const timeout = nodeData.inputs?.timeout as string + const streaming = nodeData.inputs?.streaming as boolean const obj: Partial & Partial = { temperature: parseInt(temperature, 10), @@ -128,7 +129,8 @@ class AzureChatOpenAI_ChatModels implements INode { azureOpenAIApiKey, azureOpenAIApiInstanceName, azureOpenAIApiDeploymentName, - azureOpenAIApiVersion + azureOpenAIApiVersion, + streaming: streaming ?? true } if (maxTokens) obj.maxTokens = parseInt(maxTokens, 10) diff --git a/packages/components/nodes/chatmodels/ChatAnthropic/ChatAnthropic.ts b/packages/components/nodes/chatmodels/ChatAnthropic/ChatAnthropic.ts index b13339ad4..708849e53 100644 --- a/packages/components/nodes/chatmodels/ChatAnthropic/ChatAnthropic.ts +++ b/packages/components/nodes/chatmodels/ChatAnthropic/ChatAnthropic.ts @@ -117,11 +117,13 @@ class ChatAnthropic_ChatModels implements INode { const maxTokensToSample = nodeData.inputs?.maxTokensToSample as string const topP = nodeData.inputs?.topP as string const topK = nodeData.inputs?.topK as string + const streaming = nodeData.inputs?.streaming as boolean const obj: Partial & { anthropicApiKey?: string } = { temperature: parseInt(temperature, 10), modelName, - anthropicApiKey + anthropicApiKey, + streaming: streaming ?? true } if (maxTokensToSample) obj.maxTokensToSample = parseInt(maxTokensToSample, 10) diff --git a/packages/components/nodes/chatmodels/ChatOpenAI/ChatOpenAI.ts b/packages/components/nodes/chatmodels/ChatOpenAI/ChatOpenAI.ts index 5d608c5e2..7d2098ec3 100644 --- a/packages/components/nodes/chatmodels/ChatOpenAI/ChatOpenAI.ts +++ b/packages/components/nodes/chatmodels/ChatOpenAI/ChatOpenAI.ts @@ -109,11 +109,13 @@ class ChatOpenAI_ChatModels implements INode { const frequencyPenalty = nodeData.inputs?.frequencyPenalty as string const presencePenalty = nodeData.inputs?.presencePenalty as string const timeout = nodeData.inputs?.timeout as string + const streaming = nodeData.inputs?.streaming as boolean const obj: Partial & { openAIApiKey?: string } = { temperature: parseInt(temperature, 10), modelName, - openAIApiKey + openAIApiKey, + streaming: streaming ?? true } if (maxTokens) obj.maxTokens = parseInt(maxTokens, 10) diff --git a/packages/components/nodes/llms/Azure OpenAI/AzureOpenAI.ts b/packages/components/nodes/llms/Azure OpenAI/AzureOpenAI.ts index b5d7d1e03..c19aa83aa 100644 --- a/packages/components/nodes/llms/Azure OpenAI/AzureOpenAI.ts +++ b/packages/components/nodes/llms/Azure OpenAI/AzureOpenAI.ts @@ -176,6 +176,7 @@ class AzureOpenAI_LLMs implements INode { const presencePenalty = nodeData.inputs?.presencePenalty as string const timeout = nodeData.inputs?.timeout as string const bestOf = nodeData.inputs?.bestOf as string + const streaming = nodeData.inputs?.streaming as boolean const obj: Partial & Partial = { temperature: parseInt(temperature, 10), @@ -183,7 +184,8 @@ class AzureOpenAI_LLMs implements INode { azureOpenAIApiKey, azureOpenAIApiInstanceName, azureOpenAIApiDeploymentName, - azureOpenAIApiVersion + azureOpenAIApiVersion, + streaming: streaming ?? true } if (maxTokens) obj.maxTokens = parseInt(maxTokens, 10) diff --git a/packages/components/nodes/llms/OpenAI/OpenAI.ts b/packages/components/nodes/llms/OpenAI/OpenAI.ts index af44965e3..48b1bc841 100644 --- a/packages/components/nodes/llms/OpenAI/OpenAI.ts +++ b/packages/components/nodes/llms/OpenAI/OpenAI.ts @@ -121,11 +121,13 @@ class OpenAI_LLMs implements INode { const timeout = nodeData.inputs?.timeout as string const batchSize = nodeData.inputs?.batchSize as string const bestOf = nodeData.inputs?.bestOf as string + const streaming = nodeData.inputs?.streaming as boolean const obj: Partial & { openAIApiKey?: string } = { temperature: parseInt(temperature, 10), modelName, - openAIApiKey + openAIApiKey, + streaming: streaming ?? true } if (maxTokens) obj.maxTokens = parseInt(maxTokens, 10) diff --git a/packages/components/src/utils.ts b/packages/components/src/utils.ts index 10091d602..68c098fdd 100644 --- a/packages/components/src/utils.ts +++ b/packages/components/src/utils.ts @@ -2,6 +2,8 @@ import axios from 'axios' import { load } from 'cheerio' import * as fs from 'fs' import * as path from 'path' +import { BaseCallbackHandler } from 'langchain/callbacks' +import { Server } from 'socket.io' export const numberOrExpressionRegex = '^(\\d+\\.?\\d*|{{.*}})$' //return true if string consists only numbers OR expression {{}} export const notEmptyRegex = '(.|\\s)*\\S(.|\\s)*' //return true if string is not empty or blank @@ -152,6 +154,12 @@ export const getInputVariables = (paramValue: string): string[] => { return inputVariables } +/** + * Crawl all available urls given a domain url and limit + * @param {string} url + * @param {number} limit + * @returns {string[]} + */ export const getAvailableURLs = async (url: string, limit: number) => { try { const availableUrls: string[] = [] @@ -190,3 +198,31 @@ export const getAvailableURLs = async (url: string, limit: number) => { throw new Error(`getAvailableURLs: ${err?.message}`) } } + +/** + * Custom chain handler class + */ +export class CustomChainHandler extends BaseCallbackHandler { + name = 'custom_chain_handler' + isLLMStarted = false + socketIO: Server + socketIOClientId = '' + + constructor(socketIO: Server, socketIOClientId: string) { + super() + this.socketIO = socketIO + this.socketIOClientId = socketIOClientId + } + + handleLLMNewToken(token: string) { + if (!this.isLLMStarted) { + this.isLLMStarted = true + this.socketIO.to(this.socketIOClientId).emit('start', token) + } + this.socketIO.to(this.socketIOClientId).emit('token', token) + } + + handleLLMEnd() { + this.socketIO.to(this.socketIOClientId).emit('end') + } +} diff --git a/packages/server/package.json b/packages/server/package.json index b8777a21d..a230f94a6 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -57,6 +57,7 @@ "moment-timezone": "^0.5.34", "multer": "^1.4.5-lts.1", "reflect-metadata": "^0.1.13", + "socket.io": "^4.6.1", "sqlite3": "^5.1.6", "typeorm": "^0.3.6" }, diff --git a/packages/server/src/Interface.ts b/packages/server/src/Interface.ts index 30f9fb292..0dede0248 100644 --- a/packages/server/src/Interface.ts +++ b/packages/server/src/Interface.ts @@ -115,6 +115,7 @@ export interface IncomingInput { question: string history: IMessage[] overrideConfig?: ICommonObject + socketIOClientId?: string } export interface IActiveChatflows { diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 1ce661178..59dfe3cfb 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -5,6 +5,7 @@ import cors from 'cors' import http from 'http' import * as fs from 'fs' import basicAuth from 'express-basic-auth' +import { Server } from 'socket.io' import { IChatFlow, @@ -32,7 +33,8 @@ import { mapMimeTypeToInputField, findAvailableConfigs, isSameOverrideConfig, - replaceAllAPIKeys + replaceAllAPIKeys, + isFlowValidForStream } from './utils' import { cloneDeep } from 'lodash' import { getDataSource } from './DataSource' @@ -73,7 +75,7 @@ export class App { }) } - async config() { + async config(socketIO?: Server) { // Limit is needed to allow sending/receiving base64 encoded string this.app.use(express.json({ limit: '50mb' })) this.app.use(express.urlencoded({ limit: '50mb', extended: true })) @@ -200,6 +202,30 @@ export class App { return res.json(results) }) + // Check if chatflow valid for streaming + this.app.get('/api/v1/chatflows-streaming/:id', async (req: Request, res: Response) => { + const chatflow = await this.AppDataSource.getRepository(ChatFlow).findOneBy({ + id: req.params.id + }) + if (!chatflow) return res.status(404).send(`Chatflow ${req.params.id} not found`) + + /*** Get Ending Node with Directed Graph ***/ + const flowData = chatflow.flowData + const parsedFlowData: IReactFlowObject = JSON.parse(flowData) + const nodes = parsedFlowData.nodes + const edges = parsedFlowData.edges + const { graph, nodeDependencies } = constructGraphs(nodes, edges) + const endingNodeId = getEndingNode(nodeDependencies, graph) + if (!endingNodeId) return res.status(500).send(`Ending node must be either a Chain or Agent`) + const endingNodeData = nodes.find((nd) => nd.id === endingNodeId)?.data + if (!endingNodeData) return res.status(500).send(`Ending node must be either a Chain or Agent`) + + const obj = { + isStreaming: isFlowValidForStream(nodes, endingNodeData) + } + return res.json(obj) + }) + // ---------------------------------------- // ChatMessage // ---------------------------------------- @@ -303,12 +329,12 @@ export class App { // Send input message and get prediction result (External) this.app.post('/api/v1/prediction/:id', upload.array('files'), async (req: Request, res: Response) => { - await this.processPrediction(req, res) + await this.processPrediction(req, res, socketIO) }) // Send input message and get prediction result (Internal) this.app.post('/api/v1/internal-prediction/:id', async (req: Request, res: Response) => { - await this.processPrediction(req, res, true) + await this.processPrediction(req, res, socketIO, true) }) // ---------------------------------------- @@ -464,9 +490,10 @@ export class App { * Process Prediction * @param {Request} req * @param {Response} res + * @param {Server} socketIO * @param {boolean} isInternal */ - async processPrediction(req: Request, res: Response, isInternal = false) { + async processPrediction(req: Request, res: Response, socketIO?: Server, isInternal = false) { try { const chatflowid = req.params.id let incomingInput: IncomingInput = req.body @@ -482,6 +509,8 @@ export class App { await this.validateKey(req, res, chatflow) } + let isStreamValid = false + const files = (req.files as any[]) || [] if (files.length) { @@ -542,15 +571,16 @@ export class App { } } } else { + /*** Get chatflows and prepare data ***/ + const flowData = chatflow.flowData + const parsedFlowData: IReactFlowObject = JSON.parse(flowData) + const nodes = parsedFlowData.nodes + const edges = parsedFlowData.edges + if (isRebuildNeeded()) { nodeToExecuteData = this.chatflowPool.activeChatflows[chatflowid].endingNodeData + isStreamValid = isFlowValidForStream(nodes, nodeToExecuteData) } else { - /*** Get chatflows and prepare data ***/ - const flowData = chatflow.flowData - const parsedFlowData: IReactFlowObject = JSON.parse(flowData) - const nodes = parsedFlowData.nodes - const edges = parsedFlowData.edges - /*** Get Ending Node with Directed Graph ***/ const { graph, nodeDependencies } = constructGraphs(nodes, edges) const directedGraph = graph @@ -572,6 +602,8 @@ export class App { ) } + isStreamValid = isFlowValidForStream(nodes, endingNodeData) + /*** Get Starting Nodes with Non-Directed Graph ***/ const constructedObj = constructGraphs(nodes, edges, true) const nonDirectedGraph = constructedObj.graph @@ -602,7 +634,13 @@ export class App { const nodeModule = await import(nodeInstanceFilePath) const nodeInstance = new nodeModule.nodeClass() - const result = await nodeInstance.run(nodeToExecuteData, incomingInput.question, { chatHistory: incomingInput.history }) + const result = isStreamValid + ? await nodeInstance.run(nodeToExecuteData, incomingInput.question, { + chatHistory: incomingInput.history, + socketIO, + socketIOClientId: incomingInput.socketIOClientId + }) + : await nodeInstance.run(nodeToExecuteData, incomingInput.question, { chatHistory: incomingInput.history }) return res.json(result) } @@ -629,8 +667,14 @@ export async function start(): Promise { const port = parseInt(process.env.PORT || '', 10) || 3000 const server = http.createServer(serverApp.app) + const io = new Server(server, { + cors: { + origin: '*' + } + }) + await serverApp.initDatabase() - await serverApp.config() + await serverApp.config(io) server.listen(port, () => { console.info(`⚡️[server]: Flowise Server is listening at ${port}`) diff --git a/packages/server/src/utils/index.ts b/packages/server/src/utils/index.ts index 787612844..982a82d0f 100644 --- a/packages/server/src/utils/index.ts +++ b/packages/server/src/utils/index.ts @@ -610,3 +610,28 @@ export const findAvailableConfigs = (reactFlowNodes: IReactFlowNode[]) => { return configs } + +/** + * Check to see if flow valid for stream + * @param {IReactFlowNode[]} reactFlowNodes + * @param {INodeData} endingNodeData + * @returns {boolean} + */ +export const isFlowValidForStream = (reactFlowNodes: IReactFlowNode[], endingNodeData: INodeData) => { + const streamAvailableLLMs = { + 'Chat Models': ['azureChatOpenAI', 'chatOpenAI', 'chatAnthropic'], + LLMs: ['azureOpenAI', 'openAI'] + } + + let isChatOrLLMsExist = false + for (const flowNode of reactFlowNodes) { + const data = flowNode.data + if (data.category === 'Chat Models' || data.category === 'LLMs') { + isChatOrLLMsExist = true + const validLLMs = streamAvailableLLMs[data.category] + if (!validLLMs.includes(data.name)) return false + } + } + + return isChatOrLLMsExist && endingNodeData.category === 'Chains' +} diff --git a/packages/ui/package.json b/packages/ui/package.json index b367b1830..fc2961fc4 100644 --- a/packages/ui/package.json +++ b/packages/ui/package.json @@ -36,8 +36,13 @@ "react-router": "~6.3.0", "react-router-dom": "~6.3.0", "react-simple-code-editor": "^0.11.2", + "react-syntax-highlighter": "^15.5.0", "reactflow": "^11.5.6", "redux": "^4.0.5", + "rehype-mathjax": "^4.0.2", + "remark-gfm": "^3.0.1", + "remark-math": "^5.1.1", + "socket.io-client": "^4.6.1", "yup": "^0.32.9" }, "scripts": { diff --git a/packages/ui/src/api/chatflows.js b/packages/ui/src/api/chatflows.js index eae010eda..1cd1ebb09 100644 --- a/packages/ui/src/api/chatflows.js +++ b/packages/ui/src/api/chatflows.js @@ -10,10 +10,13 @@ const updateChatflow = (id, body) => client.put(`/chatflows/${id}`, body) const deleteChatflow = (id) => client.delete(`/chatflows/${id}`) +const getIsChatflowStreaming = (id) => client.get(`/chatflows-streaming/${id}`) + export default { getAllChatflows, getSpecificChatflow, createNewChatflow, updateChatflow, - deleteChatflow + deleteChatflow, + getIsChatflowStreaming } diff --git a/packages/ui/src/themes/compStyleOverride.js b/packages/ui/src/themes/compStyleOverride.js index eb6f6de94..b7ebc8b21 100644 --- a/packages/ui/src/themes/compStyleOverride.js +++ b/packages/ui/src/themes/compStyleOverride.js @@ -1,6 +1,39 @@ export default function componentStyleOverrides(theme) { const bgColor = theme.colors?.grey50 return { + MuiCssBaseline: { + styleOverrides: { + body: { + scrollbarWidth: 'thin', + scrollbarColor: theme?.customization?.isDarkMode + ? `${theme.colors?.grey500} ${theme.colors?.darkPrimaryMain}` + : `${theme.colors?.grey300} ${theme.paper}`, + '&::-webkit-scrollbar, & *::-webkit-scrollbar': { + width: 12, + height: 12, + backgroundColor: theme?.customization?.isDarkMode ? theme.colors?.darkPrimaryMain : theme.paper + }, + '&::-webkit-scrollbar-thumb, & *::-webkit-scrollbar-thumb': { + borderRadius: 8, + backgroundColor: theme?.customization?.isDarkMode ? theme.colors?.grey500 : theme.colors?.grey300, + minHeight: 24, + border: `3px solid ${theme?.customization?.isDarkMode ? theme.colors?.darkPrimaryMain : theme.paper}` + }, + '&::-webkit-scrollbar-thumb:focus, & *::-webkit-scrollbar-thumb:focus': { + backgroundColor: theme?.customization?.isDarkMode ? theme.colors?.darkPrimary200 : theme.colors?.grey500 + }, + '&::-webkit-scrollbar-thumb:active, & *::-webkit-scrollbar-thumb:active': { + backgroundColor: theme?.customization?.isDarkMode ? theme.colors?.darkPrimary200 : theme.colors?.grey500 + }, + '&::-webkit-scrollbar-thumb:hover, & *::-webkit-scrollbar-thumb:hover': { + backgroundColor: theme?.customization?.isDarkMode ? theme.colors?.darkPrimary200 : theme.colors?.grey500 + }, + '&::-webkit-scrollbar-corner, & *::-webkit-scrollbar-corner': { + backgroundColor: theme?.customization?.isDarkMode ? theme.colors?.darkPrimaryMain : theme.paper + } + } + } + }, MuiButton: { styleOverrides: { root: { diff --git a/packages/ui/src/themes/palette.js b/packages/ui/src/themes/palette.js index a4a5104dd..9e7b7620e 100644 --- a/packages/ui/src/themes/palette.js +++ b/packages/ui/src/themes/palette.js @@ -7,7 +7,8 @@ export default function themePalette(theme) { return { mode: theme?.customization?.navType, common: { - black: theme.colors?.darkPaper + black: theme.colors?.darkPaper, + dark: theme.colors?.darkPrimaryMain }, primary: { light: theme.customization.isDarkMode ? theme.colors?.darkPrimaryLight : theme.colors?.primaryLight, diff --git a/packages/ui/src/ui-component/dialog/ConfirmDialog.js b/packages/ui/src/ui-component/dialog/ConfirmDialog.js index 8176ecd1f..6f8712f5b 100644 --- a/packages/ui/src/ui-component/dialog/ConfirmDialog.js +++ b/packages/ui/src/ui-component/dialog/ConfirmDialog.js @@ -1,5 +1,5 @@ import { createPortal } from 'react-dom' -import { Button, Dialog, DialogActions, DialogContent, DialogContentText, DialogTitle } from '@mui/material' +import { Button, Dialog, DialogActions, DialogContent, DialogTitle } from '@mui/material' import useConfirm from 'hooks/useConfirm' import { StyledButton } from 'ui-component/button/StyledButton' @@ -20,9 +20,7 @@ const ConfirmDialog = () => { {confirmState.title} - - {confirmState.description} - + {confirmState.description} diff --git a/packages/ui/src/ui-component/markdown/CodeBlock.js b/packages/ui/src/ui-component/markdown/CodeBlock.js new file mode 100644 index 000000000..77caa346c --- /dev/null +++ b/packages/ui/src/ui-component/markdown/CodeBlock.js @@ -0,0 +1,123 @@ +import { IconClipboard, IconDownload } from '@tabler/icons' +import { memo, useState } from 'react' +import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter' +import { oneDark } from 'react-syntax-highlighter/dist/esm/styles/prism' +import PropTypes from 'prop-types' +import { Box, IconButton, Popover, Typography } from '@mui/material' +import { useTheme } from '@mui/material/styles' + +const programmingLanguages = { + javascript: '.js', + python: '.py', + java: '.java', + c: '.c', + cpp: '.cpp', + 'c++': '.cpp', + 'c#': '.cs', + ruby: '.rb', + php: '.php', + swift: '.swift', + 'objective-c': '.m', + kotlin: '.kt', + typescript: '.ts', + go: '.go', + perl: '.pl', + rust: '.rs', + scala: '.scala', + haskell: '.hs', + lua: '.lua', + shell: '.sh', + sql: '.sql', + html: '.html', + css: '.css' +} + +export const CodeBlock = memo(({ language, chatflowid, isDialog, value }) => { + const theme = useTheme() + const [anchorEl, setAnchorEl] = useState(null) + const openPopOver = Boolean(anchorEl) + + const handleClosePopOver = () => { + setAnchorEl(null) + } + + const copyToClipboard = (event) => { + if (!navigator.clipboard || !navigator.clipboard.writeText) { + return + } + + navigator.clipboard.writeText(value) + setAnchorEl(event.currentTarget) + setTimeout(() => { + handleClosePopOver() + }, 1500) + } + + const downloadAsFile = () => { + const fileExtension = programmingLanguages[language] || '.file' + const suggestedFileName = `file-${chatflowid}${fileExtension}` + const fileName = suggestedFileName + + if (!fileName) { + // user pressed cancel on prompt + return + } + + const blob = new Blob([value], { type: 'text/plain' }) + const url = URL.createObjectURL(blob) + const link = document.createElement('a') + link.download = fileName + link.href = url + link.style.display = 'none' + document.body.appendChild(link) + link.click() + document.body.removeChild(link) + URL.revokeObjectURL(url) + } + + return ( +
+ +
+ {language} +
+ + + + + + Copied! + + + + + +
+
+ + + {value} + +
+ ) +}) +CodeBlock.displayName = 'CodeBlock' + +CodeBlock.propTypes = { + language: PropTypes.string, + chatflowid: PropTypes.string, + isDialog: PropTypes.bool, + value: PropTypes.string +} diff --git a/packages/ui/src/ui-component/markdown/MemoizedReactMarkdown.js b/packages/ui/src/ui-component/markdown/MemoizedReactMarkdown.js new file mode 100644 index 000000000..f9770a9f3 --- /dev/null +++ b/packages/ui/src/ui-component/markdown/MemoizedReactMarkdown.js @@ -0,0 +1,4 @@ +import { memo } from 'react' +import ReactMarkdown from 'react-markdown' + +export const MemoizedReactMarkdown = memo(ReactMarkdown, (prevProps, nextProps) => prevProps.children === nextProps.children) diff --git a/packages/ui/src/utils/genericHelper.js b/packages/ui/src/utils/genericHelper.js index c1dcb1086..fac832259 100644 --- a/packages/ui/src/utils/genericHelper.js +++ b/packages/ui/src/utils/genericHelper.js @@ -314,3 +314,23 @@ export const rearrangeToolsOrdering = (newValues, sourceNodeId) => { newValues.sort((a, b) => sortKey(a) - sortKey(b)) } + +export const throttle = (func, limit) => { + let lastFunc + let lastRan + + return (...args) => { + if (!lastRan) { + func(...args) + lastRan = Date.now() + } else { + clearTimeout(lastFunc) + lastFunc = setTimeout(() => { + if (Date.now() - lastRan >= limit) { + func(...args) + lastRan = Date.now() + } + }, limit - (Date.now() - lastRan)) + } + } +} diff --git a/packages/ui/src/views/canvas/index.js b/packages/ui/src/views/canvas/index.js index f71acbdfe..2d71f03ae 100644 --- a/packages/ui/src/views/canvas/index.js +++ b/packages/ui/src/views/canvas/index.js @@ -23,7 +23,7 @@ import ButtonEdge from './ButtonEdge' import CanvasHeader from './CanvasHeader' import AddNodes from './AddNodes' import ConfirmDialog from 'ui-component/dialog/ConfirmDialog' -import { ChatMessage } from 'views/chatmessage/ChatMessage' +import { ChatPopUp } from 'views/chatmessage/ChatPopUp' import { flowContext } from 'store/context/ReactFlowContext' // API @@ -514,7 +514,7 @@ const Canvas = () => { /> - + diff --git a/packages/ui/src/views/chatmessage/ChatExpandDialog.js b/packages/ui/src/views/chatmessage/ChatExpandDialog.js new file mode 100644 index 000000000..aa5cd5048 --- /dev/null +++ b/packages/ui/src/views/chatmessage/ChatExpandDialog.js @@ -0,0 +1,62 @@ +import { createPortal } from 'react-dom' +import PropTypes from 'prop-types' +import { useSelector } from 'react-redux' + +import { Dialog, DialogContent, DialogTitle, Button } from '@mui/material' +import { ChatMessage } from './ChatMessage' +import { StyledButton } from 'ui-component/button/StyledButton' +import { IconEraser } from '@tabler/icons' + +const ChatExpandDialog = ({ show, dialogProps, onClear, onCancel }) => { + const portalElement = document.getElementById('portal') + const customization = useSelector((state) => state.customization) + + const component = show ? ( + + +
+ {dialogProps.title} +
+ {customization.isDarkMode && ( + } + > + Clear Chat + + )} + {!customization.isDarkMode && ( + + )} +
+
+ + + +
+ ) : null + + return createPortal(component, portalElement) +} + +ChatExpandDialog.propTypes = { + show: PropTypes.bool, + dialogProps: PropTypes.object, + onClear: PropTypes.func, + onCancel: PropTypes.func +} + +export default ChatExpandDialog diff --git a/packages/ui/src/views/chatmessage/ChatMessage.css b/packages/ui/src/views/chatmessage/ChatMessage.css index a29e49ffd..9086fb137 100644 --- a/packages/ui/src/views/chatmessage/ChatMessage.css +++ b/packages/ui/src/views/chatmessage/ChatMessage.css @@ -80,7 +80,7 @@ } .markdownanswer code { - color: #15cb19; + color: #0ab126; font-weight: 500; white-space: pre-wrap !important; } @@ -92,6 +92,7 @@ .boticon, .usericon { + margin-top: 1rem; margin-right: 1rem; border-radius: 1rem; } @@ -119,3 +120,12 @@ justify-content: center; align-items: center; } + +.cloud-dialog { + width: 100%; + height: calc(100vh - 230px); + border-radius: 0.5rem; + display: flex; + justify-content: center; + align-items: center; +} diff --git a/packages/ui/src/views/chatmessage/ChatMessage.js b/packages/ui/src/views/chatmessage/ChatMessage.js index e50f5bd5e..e894f46b1 100644 --- a/packages/ui/src/views/chatmessage/ChatMessage.js +++ b/packages/ui/src/views/chatmessage/ChatMessage.js @@ -1,53 +1,38 @@ -import { useState, useRef, useEffect } from 'react' -import { useDispatch, useSelector } from 'react-redux' -import ReactMarkdown from 'react-markdown' +import { useState, useRef, useEffect, useCallback } from 'react' +import { useSelector } from 'react-redux' import PropTypes from 'prop-types' -import { enqueueSnackbar as enqueueSnackbarAction, closeSnackbar as closeSnackbarAction } from 'store/actions' +import socketIOClient from 'socket.io-client' +import { cloneDeep } from 'lodash' +import rehypeMathjax from 'rehype-mathjax' +import remarkGfm from 'remark-gfm' +import remarkMath from 'remark-math' -import { - ClickAwayListener, - Paper, - Popper, - CircularProgress, - OutlinedInput, - Divider, - InputAdornment, - IconButton, - Box, - Button -} from '@mui/material' +import { CircularProgress, OutlinedInput, Divider, InputAdornment, IconButton, Box } from '@mui/material' import { useTheme } from '@mui/material/styles' -import { IconMessage, IconX, IconSend, IconEraser } from '@tabler/icons' +import { IconSend } from '@tabler/icons' // project import -import { StyledFab } from 'ui-component/button/StyledFab' -import MainCard from 'ui-component/cards/MainCard' -import Transitions from 'ui-component/extended/Transitions' +import { CodeBlock } from 'ui-component/markdown/CodeBlock' +import { MemoizedReactMarkdown } from 'ui-component/markdown/MemoizedReactMarkdown' import './ChatMessage.css' // api import chatmessageApi from 'api/chatmessage' +import chatflowsApi from 'api/chatflows' import predictionApi from 'api/prediction' // Hooks import useApi from 'hooks/useApi' -import useConfirm from 'hooks/useConfirm' -import useNotifier from 'utils/useNotifier' -import { maxScroll } from 'store/constant' +// Const +import { baseURL, maxScroll } from 'store/constant' -export const ChatMessage = ({ chatflowid }) => { +export const ChatMessage = ({ open, chatflowid, isDialog }) => { const theme = useTheme() const customization = useSelector((state) => state.customization) - const { confirm } = useConfirm() - const dispatch = useDispatch() + const ps = useRef() - useNotifier() - const enqueueSnackbar = (...args) => dispatch(enqueueSnackbarAction(...args)) - const closeSnackbar = (...args) => dispatch(closeSnackbarAction(...args)) - - const [open, setOpen] = useState(false) const [userInput, setUserInput] = useState('') const [loading, setLoading] = useState(false) const [messages, setMessages] = useState([ @@ -56,72 +41,21 @@ export const ChatMessage = ({ chatflowid }) => { type: 'apiMessage' } ]) + const [socketIOClientId, setSocketIOClientId] = useState('') + const [isChatFlowAvailableToStream, setIsChatFlowAvailableToStream] = useState(false) const inputRef = useRef(null) - const anchorRef = useRef(null) - const prevOpen = useRef(open) const getChatmessageApi = useApi(chatmessageApi.getChatmessageFromChatflow) - - const handleClose = (event) => { - if (anchorRef.current && anchorRef.current.contains(event.target)) { - return - } - setOpen(false) - } - - const handleToggle = () => { - setOpen((prevOpen) => !prevOpen) - } - - const clearChat = async () => { - const confirmPayload = { - title: `Clear Chat History`, - description: `Are you sure you want to clear all chat history?`, - confirmButtonName: 'Clear', - cancelButtonName: 'Cancel' - } - const isConfirmed = await confirm(confirmPayload) - - if (isConfirmed) { - try { - await chatmessageApi.deleteChatmessage(chatflowid) - enqueueSnackbar({ - message: 'Succesfully cleared all chat history', - options: { - key: new Date().getTime() + Math.random(), - variant: 'success', - action: (key) => ( - - ) - } - }) - } catch (error) { - const errorData = error.response.data || `${error.response.status}: ${error.response.statusText}` - enqueueSnackbar({ - message: errorData, - options: { - key: new Date().getTime() + Math.random(), - variant: 'error', - persist: true, - action: (key) => ( - - ) - } - }) - } - } - } + const getIsChatflowStreamingApi = useApi(chatflowsApi.getIsChatflowStreaming) const scrollToBottom = () => { if (ps.current) { - ps.current.scrollTo({ top: maxScroll, behavior: 'smooth' }) + ps.current.scrollTo({ top: maxScroll }) } } + const onChange = useCallback((e) => setUserInput(e.target.value), [setUserInput]) + const addChatMessage = async (message, type) => { try { const newChatMessageBody = { @@ -135,6 +69,15 @@ export const ChatMessage = ({ chatflowid }) => { } } + const updateLastMessage = (text) => { + setMessages((prevMessages) => { + let allMessages = [...cloneDeep(prevMessages)] + if (allMessages[allMessages.length - 1].type === 'userMessage') return allMessages + allMessages[allMessages.length - 1].message += text + return allMessages + }) + } + // Handle errors const handleError = (message = 'Oops! There seems to be an error. Please try again.') => { message = message.replace(`Unable to parse JSON response from chat agent.\n\n`, '') @@ -143,7 +86,7 @@ export const ChatMessage = ({ chatflowid }) => { setLoading(false) setUserInput('') setTimeout(() => { - inputRef.current.focus() + inputRef.current?.focus() }, 100) } @@ -161,18 +104,22 @@ export const ChatMessage = ({ chatflowid }) => { // Send user question and history to API try { - const response = await predictionApi.sendMessageAndGetPrediction(chatflowid, { + const params = { question: userInput, history: messages.filter((msg) => msg.message !== 'Hi there! How can I help?') - }) + } + if (isChatFlowAvailableToStream) params.socketIOClientId = socketIOClientId + + const response = await predictionApi.sendMessageAndGetPrediction(chatflowid, params) + if (response.data) { const data = response.data - setMessages((prevMessages) => [...prevMessages, { message: data, type: 'apiMessage' }]) + if (!isChatFlowAvailableToStream) setMessages((prevMessages) => [...prevMessages, { message: data, type: 'apiMessage' }]) addChatMessage(data, 'apiMessage') setLoading(false) setUserInput('') setTimeout(() => { - inputRef.current.focus() + inputRef.current?.focus() scrollToBottom() }, 100) } @@ -210,22 +157,47 @@ export const ChatMessage = ({ chatflowid }) => { // eslint-disable-next-line react-hooks/exhaustive-deps }, [getChatmessageApi.data]) + // Get chatflow streaming capability + useEffect(() => { + if (getIsChatflowStreamingApi.data) { + setIsChatFlowAvailableToStream(getIsChatflowStreamingApi.data?.isStreaming ?? false) + } + + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [getIsChatflowStreamingApi.data]) + // Auto scroll chat to bottom useEffect(() => { scrollToBottom() }, [messages]) useEffect(() => { - if (prevOpen.current === true && open === false) { - anchorRef.current.focus() + if (isDialog && inputRef) { + setTimeout(() => { + inputRef.current?.focus() + }, 100) } + }, [isDialog, inputRef]) + useEffect(() => { + let socket if (open && chatflowid) { getChatmessageApi.request(chatflowid) + getIsChatflowStreamingApi.request(chatflowid) scrollToBottom() - } - prevOpen.current = open + socket = socketIOClient(baseURL) + + socket.on('connect', () => { + setSocketIOClientId(socket.id) + }) + + socket.on('start', () => { + setMessages((prevMessages) => [...prevMessages, { message: '', type: 'apiMessage' }]) + }) + + socket.on('token', updateLastMessage) + } return () => { setUserInput('') @@ -236,6 +208,10 @@ export const ChatMessage = ({ chatflowid }) => { type: 'apiMessage' } ]) + if (socket) { + socket.disconnect() + setSocketIOClientId('') + } } // eslint-disable-next-line react-hooks/exhaustive-deps @@ -243,151 +219,121 @@ export const ChatMessage = ({ chatflowid }) => { return ( <> - - {open ? : } - - {open && ( - - - - )} - +
+ {messages && + messages.map((message, index) => { + return ( + // The latest message sent by the user will be animated while waiting for a response + + {/* Display the correct icon depending on the message type */} + {message.type === 'apiMessage' ? ( + AI + ) : ( + Me + )} +
+ {/* Messages are being rendered in Markdown format */} + + ) : ( + + {children} + + ) + } + }} + > + {message.message} + +
+
+ ) + })} +
+ + +
+
+
+ + + {loading ? ( +
+ +
+ ) : ( + // Send icon SVG in input field + + )} +
+ } - } - ] - }} - sx={{ zIndex: 1000 }} - > - {({ TransitionProps }) => ( - - - - -
-
- {messages.map((message, index) => { - return ( - // The latest message sent by the user will be animated while waiting for a response - - {/* Display the correct icon depending on the message type */} - {message.type === 'apiMessage' ? ( - AI - ) : ( - Me - )} -
- {/* Messages are being rendered in Markdown format */} - {message.message} -
-
- ) - })} -
-
- -
-
- - setUserInput(e.target.value)} - endAdornment={ - - - {loading ? ( -
- -
- ) : ( - // Send icon SVG in input field - - )} -
-
- } - /> - -
-
-
-
-
-
- )} - + /> + +
+
) } -ChatMessage.propTypes = { chatflowid: PropTypes.string } +ChatMessage.propTypes = { + open: PropTypes.bool, + chatflowid: PropTypes.string, + isDialog: PropTypes.bool +} diff --git a/packages/ui/src/views/chatmessage/ChatPopUp.js b/packages/ui/src/views/chatmessage/ChatPopUp.js new file mode 100644 index 000000000..93050c3a8 --- /dev/null +++ b/packages/ui/src/views/chatmessage/ChatPopUp.js @@ -0,0 +1,208 @@ +import { useState, useRef, useEffect } from 'react' +import { useDispatch } from 'react-redux' +import PropTypes from 'prop-types' + +import { ClickAwayListener, Paper, Popper, Button } from '@mui/material' +import { useTheme } from '@mui/material/styles' +import { IconMessage, IconX, IconEraser, IconArrowsMaximize } from '@tabler/icons' + +// project import +import { StyledFab } from 'ui-component/button/StyledFab' +import MainCard from 'ui-component/cards/MainCard' +import Transitions from 'ui-component/extended/Transitions' +import { ChatMessage } from './ChatMessage' +import ChatExpandDialog from './ChatExpandDialog' + +// api +import chatmessageApi from 'api/chatmessage' + +// Hooks +import useConfirm from 'hooks/useConfirm' +import useNotifier from 'utils/useNotifier' + +// Const +import { enqueueSnackbar as enqueueSnackbarAction, closeSnackbar as closeSnackbarAction } from 'store/actions' + +export const ChatPopUp = ({ chatflowid }) => { + const theme = useTheme() + const { confirm } = useConfirm() + const dispatch = useDispatch() + + useNotifier() + const enqueueSnackbar = (...args) => dispatch(enqueueSnackbarAction(...args)) + const closeSnackbar = (...args) => dispatch(closeSnackbarAction(...args)) + + const [open, setOpen] = useState(false) + const [showExpandDialog, setShowExpandDialog] = useState(false) + const [expandDialogProps, setExpandDialogProps] = useState({}) + + const anchorRef = useRef(null) + const prevOpen = useRef(open) + + const handleClose = (event) => { + if (anchorRef.current && anchorRef.current.contains(event.target)) { + return + } + setOpen(false) + } + + const handleToggle = () => { + setOpen((prevOpen) => !prevOpen) + } + + const expandChat = () => { + const props = { + open: true, + chatflowid: chatflowid + } + setExpandDialogProps(props) + setShowExpandDialog(true) + } + + const resetChatDialog = () => { + const props = { + ...expandDialogProps, + open: false + } + setExpandDialogProps(props) + setTimeout(() => { + const resetProps = { + ...expandDialogProps, + open: true + } + setExpandDialogProps(resetProps) + }, 500) + } + + const clearChat = async () => { + const confirmPayload = { + title: `Clear Chat History`, + description: `Are you sure you want to clear all chat history?`, + confirmButtonName: 'Clear', + cancelButtonName: 'Cancel' + } + const isConfirmed = await confirm(confirmPayload) + + if (isConfirmed) { + try { + await chatmessageApi.deleteChatmessage(chatflowid) + resetChatDialog() + enqueueSnackbar({ + message: 'Succesfully cleared all chat history', + options: { + key: new Date().getTime() + Math.random(), + variant: 'success', + action: (key) => ( + + ) + } + }) + } catch (error) { + const errorData = error.response.data || `${error.response.status}: ${error.response.statusText}` + enqueueSnackbar({ + message: errorData, + options: { + key: new Date().getTime() + Math.random(), + variant: 'error', + persist: true, + action: (key) => ( + + ) + } + }) + } + } + } + + useEffect(() => { + if (prevOpen.current === true && open === false) { + anchorRef.current.focus() + } + prevOpen.current = open + + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [open, chatflowid]) + + return ( + <> + + {open ? : } + + {open && ( + + + + )} + {open && ( + + + + )} + + {({ TransitionProps }) => ( + + + + + + + + + + )} + + setShowExpandDialog(false)} + > + + ) +} + +ChatPopUp.propTypes = { chatflowid: PropTypes.string }