From b071790a5a5e607e7d2a97c394eeeeb19144b99e Mon Sep 17 00:00:00 2001
From: Henry
Date: Thu, 8 Jun 2023 23:51:34 +0100
Subject: [PATCH] add return source documents functionality

---
 .../ConversationalRetrievalQAChain.ts         |  87 ++++++--
 packages/components/package.json              |   2 +-
 packages/components/src/Interface.ts          |   2 +-
 packages/components/src/utils.ts              |  11 +-
 packages/server/src/Interface.ts              |   1 +
 packages/server/src/entity/ChatMessage.ts     |   3 +
 packages/server/src/index.ts                  |   2 +-
 .../MainLayout/Header/ProfileSection/index.js |  17 +-
 .../ui/src/ui-component/dialog/AboutDialog.js |  85 ++++++++
 .../ui-component/dialog/SourceDocDialog.js    |  57 +++++
 .../ui/src/views/chatmessage/ChatMessage.js   | 195 +++++++++++-------
 11 files changed, 371 insertions(+), 91 deletions(-)
 create mode 100644 packages/ui/src/ui-component/dialog/AboutDialog.js
 create mode 100644 packages/ui/src/ui-component/dialog/SourceDocDialog.js

diff --git a/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts b/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts
index 659e1e2e6..3b7e1413f 100644
--- a/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts
+++ b/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts
@@ -2,7 +2,9 @@ import { BaseLanguageModel } from 'langchain/base_language'
 import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
 import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
 import { ConversationalRetrievalQAChain } from 'langchain/chains'
-import { BaseRetriever } from 'langchain/schema'
+import { AIChatMessage, BaseRetriever, HumanChatMessage } from 'langchain/schema'
+import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory'
+import { PromptTemplate } from 'langchain/prompts'
 
 const default_qa_template = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
 
@@ -47,6 +49,12 @@ class ConversationalRetrievalQAChain_Chains implements INode {
                 name: 'vectorStoreRetriever',
                 type: 'BaseRetriever'
             },
+            {
+                label: 'Return Source Documents',
+                name: 'returnSourceDocuments',
+                type: 'boolean',
+                optional: true
+            },
             {
                 label: 'System Message',
                 name: 'systemMessagePrompt',
@@ -56,6 +64,31 @@ class ConversationalRetrievalQAChain_Chains implements INode {
                 optional: true,
                 placeholder:
                     'I want you to act as a document that I am having a conversation with. Your name is "AI Assistant". You will provide me with answers from the given info. If the answer is not included, say exactly "Hmm, I am not sure." and stop after that. Refuse to answer any question not about the info. Never break character.'
+            },
+            {
+                label: 'Chain Option',
+                name: 'chainOption',
+                type: 'options',
+                options: [
+                    {
+                        label: 'MapReduceDocumentsChain',
+                        name: 'map_reduce',
+                        description:
+                            'Suitable for QA tasks over larger documents and can run the preprocessing step in parallel, reducing the running time'
+                    },
+                    {
+                        label: 'RefineDocumentsChain',
+                        name: 'refine',
+                        description: 'Suitable for QA tasks over a large number of documents.'
+                    },
+                    {
+                        label: 'StuffDocumentsChain',
+                        name: 'stuff',
+                        description: 'Suitable for QA tasks over a small number of documents.'
+                    }
+                ],
+                additionalParams: true,
+                optional: true
             }
         ]
     }
@@ -64,44 +97,64 @@ class ConversationalRetrievalQAChain_Chains implements INode {
         const model = nodeData.inputs?.model as BaseLanguageModel
         const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever
         const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
+        const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
+        const chainOption = nodeData.inputs?.chainOption as string
 
-        const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, {
+        const obj: any = {
             verbose: process.env.DEBUG === 'true' ? true : false,
-            qaTemplate: systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template
-        })
+            qaChainOptions: {
+                type: 'stuff',
+                prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
+            },
+            memory: new BufferMemory({
+                memoryKey: 'chat_history',
+                inputKey: 'question',
+                outputKey: 'text',
+                returnMessages: true
+            })
+        }
+        if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
+        if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }
+
+        const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
         return chain
     }
 
-    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
+    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
         const chain = nodeData.instance as ConversationalRetrievalQAChain
+        const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
         let model = nodeData.inputs?.model
 
         // Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
         model.streaming = false
         chain.questionGeneratorChain.llm = model
 
-        let chatHistory = ''
+        const obj = { question: input }
 
-        if (options && options.chatHistory) {
+        if (chain.memory && options && options.chatHistory) {
+            const chatHistory = []
             const histories: IMessage[] = options.chatHistory
-            chatHistory = histories
-                .map((item) => {
-                    return item.message
-                })
-                .join('')
-        }
+            const memory = chain.memory as BaseChatMemory
 
-        const obj = {
-            question: input,
-            chat_history: chatHistory ? chatHistory : []
+            for (const message of histories) {
+                if (message.type === 'apiMessage') {
+                    chatHistory.push(new AIChatMessage(message.message))
+                } else if (message.type === 'userMessage') {
+                    chatHistory.push(new HumanChatMessage(message.message))
+                }
+            }
+            memory.chatHistory = new ChatMessageHistory(chatHistory)
+            chain.memory = memory
         }
 
         if (options.socketIO && options.socketIOClientId) {
-            const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
+            const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, undefined, returnSourceDocuments)
             const res = await chain.call(obj, [handler])
+            if (res.text && res.sourceDocuments) return res
             return res?.text
         } else {
             const res = await chain.call(obj)
+            if (res.text && res.sourceDocuments) return res
             return res?.text
         }
     }
diff --git a/packages/components/package.json b/packages/components/package.json
index 161792b2d..46b441bdb 100644
--- a/packages/components/package.json
+++ b/packages/components/package.json
@@ -31,7 +31,7 @@
         "faiss-node": "^0.2.1",
         "form-data": "^4.0.0",
         "graphql": "^16.6.0",
-        "langchain": "^0.0.84",
+        "langchain": "^0.0.91",
        "linkifyjs": "^4.1.1",
         "mammoth": "^1.5.1",
         "moment": "^2.29.3",
diff --git a/packages/components/src/Interface.ts b/packages/components/src/Interface.ts
index c14119399..bd94cca89 100644
--- a/packages/components/src/Interface.ts
+++ b/packages/components/src/Interface.ts
@@ -75,7 +75,7 @@ export interface INode extends INodeProperties {
     inputs?: INodeParams[]
     output?: INodeOutputsValue[]
     init?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<any>
-    run?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<string>
+    run?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<string | ICommonObject>
 }
 
 export interface INodeData extends INodeProperties {
diff --git a/packages/components/src/utils.ts b/packages/components/src/utils.ts
index 08d32bab1..3e8015125 100644
--- a/packages/components/src/utils.ts
+++ b/packages/components/src/utils.ts
@@ -4,6 +4,7 @@ import * as fs from 'fs'
 import * as path from 'path'
 import { BaseCallbackHandler } from 'langchain/callbacks'
 import { Server } from 'socket.io'
+import { ChainValues } from 'langchain/dist/schema'
 
 export const numberOrExpressionRegex = '^(\\d+\\.?\\d*|{{.*}})$' //return true if string consists only numbers OR expression {{}}
 export const notEmptyRegex = '(.|\\s)*\\S(.|\\s)*' //return true if string is not empty or blank
@@ -208,12 +209,14 @@ export class CustomChainHandler extends BaseCallbackHandler {
     socketIO: Server
     socketIOClientId = ''
     skipK = 0 // Skip streaming for first K numbers of handleLLMStart
+    returnSourceDocuments = false
 
-    constructor(socketIO: Server, socketIOClientId: string, skipK?: number) {
+    constructor(socketIO: Server, socketIOClientId: string, skipK?: number, returnSourceDocuments?: boolean) {
         super()
         this.socketIO = socketIO
         this.socketIOClientId = socketIOClientId
         this.skipK = skipK ?? this.skipK
+        this.returnSourceDocuments = returnSourceDocuments ??
this.returnSourceDocuments } handleLLMStart() { @@ -233,4 +236,10 @@ export class CustomChainHandler extends BaseCallbackHandler { handleLLMEnd() { this.socketIO.to(this.socketIOClientId).emit('end') } + + handleChainEnd(outputs: ChainValues): void | Promise { + if (this.returnSourceDocuments) { + this.socketIO.to(this.socketIOClientId).emit('sourceDocuments', outputs?.sourceDocuments) + } + } } diff --git a/packages/server/src/Interface.ts b/packages/server/src/Interface.ts index 0dede0248..b6876df3f 100644 --- a/packages/server/src/Interface.ts +++ b/packages/server/src/Interface.ts @@ -21,6 +21,7 @@ export interface IChatMessage { content: string chatflowid: string createdDate: Date + sourceDocuments: string } export interface IComponentNodes { diff --git a/packages/server/src/entity/ChatMessage.ts b/packages/server/src/entity/ChatMessage.ts index 3380c86cd..236dc5f93 100644 --- a/packages/server/src/entity/ChatMessage.ts +++ b/packages/server/src/entity/ChatMessage.ts @@ -17,6 +17,9 @@ export class ChatMessage implements IChatMessage { @Column() content: string + @Column({ nullable: true }) + sourceDocuments: string + @CreateDateColumn() createdDate: Date } diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 0e030734a..3a0c64cda 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -90,7 +90,7 @@ export class App { const basicAuthMiddleware = basicAuth({ users: { [username]: password } }) - const whitelistURLs = ['/api/v1/prediction/', '/api/v1/node-icon/'] + const whitelistURLs = ['/api/v1/prediction/', '/api/v1/node-icon/', '/api/v1/chatflows-streaming'] this.app.use((req, res, next) => { if (req.url.includes('/api/v1/')) { whitelistURLs.some((url) => req.url.includes(url)) ? next() : basicAuthMiddleware(req, res, next) diff --git a/packages/ui/src/layout/MainLayout/Header/ProfileSection/index.js b/packages/ui/src/layout/MainLayout/Header/ProfileSection/index.js index f6f6a7307..41de3dd44 100644 --- a/packages/ui/src/layout/MainLayout/Header/ProfileSection/index.js +++ b/packages/ui/src/layout/MainLayout/Header/ProfileSection/index.js @@ -27,9 +27,10 @@ import PerfectScrollbar from 'react-perfect-scrollbar' import MainCard from 'ui-component/cards/MainCard' import Transitions from 'ui-component/extended/Transitions' import { BackdropLoader } from 'ui-component/loading/BackdropLoader' +import AboutDialog from 'ui-component/dialog/AboutDialog' // assets -import { IconLogout, IconSettings, IconFileExport, IconFileDownload } from '@tabler/icons' +import { IconLogout, IconSettings, IconFileExport, IconFileDownload, IconInfoCircle } from '@tabler/icons' // API import databaseApi from 'api/database' @@ -49,6 +50,7 @@ const ProfileSection = ({ username, handleLogout }) => { const [open, setOpen] = useState(false) const [loading, setLoading] = useState(false) + const [aboutDialogOpen, setAboutDialogOpen] = useState(false) const anchorRef = useRef(null) const uploadRef = useRef(null) @@ -215,6 +217,18 @@ const ProfileSection = ({ username, handleLogout }) => { Export Database} /> + { + setOpen(false) + setAboutDialogOpen(true) + }} + > + + + + About Flowise} /> + {localStorage.getItem('username') && localStorage.getItem('password') && ( { handleFileUpload(e)} /> + setAboutDialogOpen(false)} /> ) } diff --git a/packages/ui/src/ui-component/dialog/AboutDialog.js b/packages/ui/src/ui-component/dialog/AboutDialog.js new file mode 100644 index 000000000..54c077d18 --- /dev/null +++ b/packages/ui/src/ui-component/dialog/AboutDialog.js @@ -0,0 
+1,85 @@ +import { createPortal } from 'react-dom' +import { useState, useEffect } from 'react' +import PropTypes from 'prop-types' +import { Dialog, DialogContent, DialogTitle, TableContainer, Table, TableHead, TableRow, TableCell, TableBody, Paper } from '@mui/material' +import moment from 'moment' +import axios from 'axios' + +const fetchLatestVer = async ({ api }) => { + let apiReturn = await axios + .get(api) + .then(async function (response) { + return response.data + }) + .catch(function (error) { + console.error(error) + }) + return apiReturn +} + +const AboutDialog = ({ show, onCancel }) => { + const portalElement = document.getElementById('portal') + + const [data, setData] = useState({}) + + useEffect(() => { + if (show) { + const fetchData = async (api) => { + let response = await fetchLatestVer({ api }) + setData(response) + } + + fetchData('https://api.github.com/repos/FlowiseAI/Flowise/releases/latest') + } + + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [show]) + + const component = show ? ( + + + Flowise Version + + + {data && ( + + + + + Latest Version + Published At + + + + + + + {data.name} + + + {moment(data.published_at).fromNow()} + + +
+
+ )} +
+
+ ) : null + + return createPortal(component, portalElement) +} + +AboutDialog.propTypes = { + show: PropTypes.bool, + onCancel: PropTypes.func +} + +export default AboutDialog diff --git a/packages/ui/src/ui-component/dialog/SourceDocDialog.js b/packages/ui/src/ui-component/dialog/SourceDocDialog.js new file mode 100644 index 000000000..a088a6c49 --- /dev/null +++ b/packages/ui/src/ui-component/dialog/SourceDocDialog.js @@ -0,0 +1,57 @@ +import { createPortal } from 'react-dom' +import { useState, useEffect } from 'react' +import { useSelector } from 'react-redux' +import PropTypes from 'prop-types' +import { Dialog, DialogContent, DialogTitle } from '@mui/material' +import ReactJson from 'react-json-view' + +const SourceDocDialog = ({ show, dialogProps, onCancel }) => { + const portalElement = document.getElementById('portal') + const customization = useSelector((state) => state.customization) + + const [data, setData] = useState({}) + + useEffect(() => { + if (dialogProps.data) setData(dialogProps.data) + + return () => { + setData({}) + } + }, [dialogProps]) + + const component = show ? ( + + + Source Document + + + + + + ) : null + + return createPortal(component, portalElement) +} + +SourceDocDialog.propTypes = { + show: PropTypes.bool, + dialogProps: PropTypes.object, + onCancel: PropTypes.func +} + +export default SourceDocDialog diff --git a/packages/ui/src/views/chatmessage/ChatMessage.js b/packages/ui/src/views/chatmessage/ChatMessage.js index e894f46b1..65bad2123 100644 --- a/packages/ui/src/views/chatmessage/ChatMessage.js +++ b/packages/ui/src/views/chatmessage/ChatMessage.js @@ -7,13 +7,14 @@ import rehypeMathjax from 'rehype-mathjax' import remarkGfm from 'remark-gfm' import remarkMath from 'remark-math' -import { CircularProgress, OutlinedInput, Divider, InputAdornment, IconButton, Box } from '@mui/material' +import { CircularProgress, OutlinedInput, Divider, InputAdornment, IconButton, Box, Chip } from '@mui/material' import { useTheme } from '@mui/material/styles' import { IconSend } from '@tabler/icons' // project import import { CodeBlock } from 'ui-component/markdown/CodeBlock' import { MemoizedReactMarkdown } from 'ui-component/markdown/MemoizedReactMarkdown' +import SourceDocDialog from 'ui-component/dialog/SourceDocDialog' import './ChatMessage.css' // api @@ -43,11 +44,18 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => { ]) const [socketIOClientId, setSocketIOClientId] = useState('') const [isChatFlowAvailableToStream, setIsChatFlowAvailableToStream] = useState(false) + const [sourceDialogOpen, setSourceDialogOpen] = useState(false) + const [sourceDialogProps, setSourceDialogProps] = useState({}) const inputRef = useRef(null) const getChatmessageApi = useApi(chatmessageApi.getChatmessageFromChatflow) const getIsChatflowStreamingApi = useApi(chatflowsApi.getIsChatflowStreaming) + const onSourceDialogClick = (data) => { + setSourceDialogProps({ data }) + setSourceDialogOpen(true) + } + const scrollToBottom = () => { if (ps.current) { ps.current.scrollTo({ top: maxScroll }) @@ -56,13 +64,14 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => { const onChange = useCallback((e) => setUserInput(e.target.value), [setUserInput]) - const addChatMessage = async (message, type) => { + const addChatMessage = async (message, type, sourceDocuments) => { try { const newChatMessageBody = { role: type, content: message, chatflowid: chatflowid } + if (sourceDocuments) newChatMessageBody.sourceDocuments = JSON.stringify(sourceDocuments) await 
chatmessageApi.createNewChatmessage(chatflowid, newChatMessageBody) } catch (error) { console.error(error) @@ -78,6 +87,15 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => { }) } + const updateLastMessageSourceDocuments = (sourceDocuments) => { + setMessages((prevMessages) => { + let allMessages = [...cloneDeep(prevMessages)] + if (allMessages[allMessages.length - 1].type === 'userMessage') return allMessages + allMessages[allMessages.length - 1].sourceDocuments = sourceDocuments + return allMessages + }) + } + // Handle errors const handleError = (message = 'Oops! There seems to be an error. Please try again.') => { message = message.replace(`Unable to parse JSON response from chat agent.\n\n`, '') @@ -114,8 +132,20 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => { if (response.data) { const data = response.data - if (!isChatFlowAvailableToStream) setMessages((prevMessages) => [...prevMessages, { message: data, type: 'apiMessage' }]) - addChatMessage(data, 'apiMessage') + if (typeof data === 'object' && data.text && data.sourceDocuments) { + if (!isChatFlowAvailableToStream) { + setMessages((prevMessages) => [ + ...prevMessages, + { message: data.text, sourceDocuments: data.sourceDocuments, type: 'apiMessage' } + ]) + } + addChatMessage(data.text, 'apiMessage', data.sourceDocuments) + } else { + if (!isChatFlowAvailableToStream) { + setMessages((prevMessages) => [...prevMessages, { message: data, type: 'apiMessage' }]) + } + addChatMessage(data, 'apiMessage') + } setLoading(false) setUserInput('') setTimeout(() => { @@ -146,10 +176,12 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => { if (getChatmessageApi.data) { const loadedMessages = [] for (const message of getChatmessageApi.data) { - loadedMessages.push({ + const obj = { message: message.content, type: message.role - }) + } + if (message.sourceDocuments) obj.sourceDocuments = JSON.parse(message.sourceDocuments) + loadedMessages.push(obj) } setMessages((prevMessages) => [...prevMessages, ...loadedMessages]) } @@ -196,6 +228,8 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => { setMessages((prevMessages) => [...prevMessages, { message: '', type: 'apiMessage' }]) }) + socket.on('sourceDocuments', updateLastMessageSourceDocuments) + socket.on('token', updateLastMessage) } @@ -225,69 +259,91 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => { messages.map((message, index) => { return ( // The latest message sent by the user will be animated while waiting for a response - - {/* Display the correct icon depending on the message type */} - {message.type === 'apiMessage' ? ( - AI - ) : ( - Me - )} -
- {/* Messages are being rendered in Markdown format */} - - ) : ( - - {children} - - ) - } - }} - > - {message.message} - -
-
+ <> + + {/* Display the correct icon depending on the message type */} + {message.type === 'apiMessage' ? ( + AI + ) : ( + Me + )} +
+
+ {/* Messages are being rendered in Markdown format */} + + ) : ( + + {children} + + ) + } + }} + > + {message.message} + +
+ {message.sourceDocuments && ( +
+ {message.sourceDocuments.map((source, index) => { + return ( + onSourceDialogClick(source)} + /> + ) + })} +
+ )} +
+
+ ) })} @@ -328,6 +384,7 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => { + setSourceDialogOpen(false)} /> ) }
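
Reviewer note (not part of the patch): with this change a prediction response is no longer always a plain string. When Return Source Documents is enabled, the chain's run() resolves to an object carrying `text` and `sourceDocuments`, and the UI in ChatMessage.js branches on that shape. Below is a minimal TypeScript sketch of a client handling both shapes; only the `/api/v1/prediction/` route prefix and the `text`/`sourceDocuments` field names come from this diff, while the base URL, chatflow id, request-body shape, and document fields are assumptions.

// Hypothetical client for the prediction endpoint; not taken from this repository.
interface SourceDocument {
    pageContent?: string // assumed LangChain-style document shape
    metadata?: Record<string, unknown>
}

type PredictionResponse = string | { text: string; sourceDocuments?: SourceDocument[] }

async function askChatflow(baseUrl: string, chatflowId: string, question: string): Promise<void> {
    const res = await fetch(`${baseUrl}/api/v1/prediction/${chatflowId}`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ question }) // assumed request body shape
    })
    const data: PredictionResponse = await res.json()

    if (typeof data === 'string') {
        // Return Source Documents disabled: plain answer text, as before
        console.log(data)
    } else {
        // Return Source Documents enabled: answer plus the retrieved documents
        console.log(data.text)
        for (const doc of data.sourceDocuments ?? []) {
            console.log('source:', doc.metadata)
        }
    }
}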
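Likewise, on the streaming path the patch makes CustomChainHandler.handleChainEnd emit a 'sourceDocuments' socket.io event next to the existing 'token' and 'end' events, and ChatMessage.js subscribes to it. A hedged listener sketch follows; the event names come from the diff, while the server URL and the exact wiring of the socket id into the prediction request are assumptions.

import { io } from 'socket.io-client'

// Assumed server address; Flowise's actual host/port may differ.
const socket = io('http://localhost:3000')

socket.on('connect', () => {
    // The socket id would be passed along with the prediction request as
    // socketIOClientId (assumed wiring, mirroring the UI code in this patch).
    console.log('connected as', socket.id)
})

// Streamed answer tokens
socket.on('token', (token: string) => process.stdout.write(token))

// Emitted once by handleChainEnd when Return Source Documents is enabled
socket.on('sourceDocuments', (docs: unknown) => {
    console.log('\nsource documents:', JSON.stringify(docs, null, 2))
})

socket.on('end', () => socket.disconnect())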