add return source documents functioanality

This commit is contained in:
Henry 2023-06-08 23:51:34 +01:00
parent ab00214ec2
commit b071790a5a
11 changed files with 371 additions and 91 deletions

View File

@ -2,7 +2,9 @@ import { BaseLanguageModel } from 'langchain/base_language'
import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
import { ConversationalRetrievalQAChain } from 'langchain/chains'
import { BaseRetriever } from 'langchain/schema'
import { AIChatMessage, BaseRetriever, HumanChatMessage } from 'langchain/schema'
import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory'
import { PromptTemplate } from 'langchain/prompts'
const default_qa_template = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
@ -47,6 +49,12 @@ class ConversationalRetrievalQAChain_Chains implements INode {
name: 'vectorStoreRetriever',
type: 'BaseRetriever'
},
{
label: 'Return Source Documents',
name: 'returnSourceDocuments',
type: 'boolean',
optional: true
},
{
label: 'System Message',
name: 'systemMessagePrompt',
@ -56,6 +64,31 @@ class ConversationalRetrievalQAChain_Chains implements INode {
optional: true,
placeholder:
'I want you to act as a document that I am having a conversation with. Your name is "AI Assistant". You will provide me with answers from the given info. If the answer is not included, say exactly "Hmm, I am not sure." and stop after that. Refuse to answer any question not about the info. Never break character.'
},
{
label: 'Chain Option',
name: 'chainOption',
type: 'options',
options: [
{
label: 'MapReduceDocumentsChain',
name: 'map_reduce',
description:
'Suitable for QA tasks over larger documents and can run the preprocessing step in parallel, reducing the running time'
},
{
label: 'RefineDocumentsChain',
name: 'refine',
description: 'Suitable for QA tasks over a large number of documents.'
},
{
label: 'StuffDocumentsChain',
name: 'stuff',
description: 'Suitable for QA tasks over a small number of documents.'
}
],
additionalParams: true,
optional: true
}
]
}
@ -64,44 +97,64 @@ class ConversationalRetrievalQAChain_Chains implements INode {
const model = nodeData.inputs?.model as BaseLanguageModel
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever
const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const chainOption = nodeData.inputs?.chainOption as string
const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, {
const obj: any = {
verbose: process.env.DEBUG === 'true' ? true : false,
qaTemplate: systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template
})
qaChainOptions: {
type: 'stuff',
prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
},
memory: new BufferMemory({
memoryKey: 'chat_history',
inputKey: 'question',
outputKey: 'text',
returnMessages: true
})
}
if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }
const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
return chain
}
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
const chain = nodeData.instance as ConversationalRetrievalQAChain
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
let model = nodeData.inputs?.model
// Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
model.streaming = false
chain.questionGeneratorChain.llm = model
let chatHistory = ''
const obj = { question: input }
if (options && options.chatHistory) {
if (chain.memory && options && options.chatHistory) {
const chatHistory = []
const histories: IMessage[] = options.chatHistory
chatHistory = histories
.map((item) => {
return item.message
})
.join('')
}
const memory = chain.memory as BaseChatMemory
const obj = {
question: input,
chat_history: chatHistory ? chatHistory : []
for (const message of histories) {
if (message.type === 'apiMessage') {
chatHistory.push(new AIChatMessage(message.message))
} else if (message.type === 'userMessage') {
chatHistory.push(new HumanChatMessage(message.message))
}
}
memory.chatHistory = new ChatMessageHistory(chatHistory)
chain.memory = memory
}
if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, undefined, returnSourceDocuments)
const res = await chain.call(obj, [handler])
if (res.text && res.sourceDocuments) return res
return res?.text
} else {
const res = await chain.call(obj)
if (res.text && res.sourceDocuments) return res
return res?.text
}
}

View File

@ -31,7 +31,7 @@
"faiss-node": "^0.2.1",
"form-data": "^4.0.0",
"graphql": "^16.6.0",
"langchain": "^0.0.84",
"langchain": "^0.0.91",
"linkifyjs": "^4.1.1",
"mammoth": "^1.5.1",
"moment": "^2.29.3",

View File

@ -75,7 +75,7 @@ export interface INode extends INodeProperties {
inputs?: INodeParams[]
output?: INodeOutputsValue[]
init?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<any>
run?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<string>
run?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<string | ICommonObject>
}
export interface INodeData extends INodeProperties {

View File

@ -4,6 +4,7 @@ import * as fs from 'fs'
import * as path from 'path'
import { BaseCallbackHandler } from 'langchain/callbacks'
import { Server } from 'socket.io'
import { ChainValues } from 'langchain/dist/schema'
export const numberOrExpressionRegex = '^(\\d+\\.?\\d*|{{.*}})$' //return true if string consists only numbers OR expression {{}}
export const notEmptyRegex = '(.|\\s)*\\S(.|\\s)*' //return true if string is not empty or blank
@ -208,12 +209,14 @@ export class CustomChainHandler extends BaseCallbackHandler {
socketIO: Server
socketIOClientId = ''
skipK = 0 // Skip streaming for first K numbers of handleLLMStart
returnSourceDocuments = false
constructor(socketIO: Server, socketIOClientId: string, skipK?: number) {
constructor(socketIO: Server, socketIOClientId: string, skipK?: number, returnSourceDocuments?: boolean) {
super()
this.socketIO = socketIO
this.socketIOClientId = socketIOClientId
this.skipK = skipK ?? this.skipK
this.returnSourceDocuments = returnSourceDocuments ?? this.returnSourceDocuments
}
handleLLMStart() {
@ -233,4 +236,10 @@ export class CustomChainHandler extends BaseCallbackHandler {
handleLLMEnd() {
this.socketIO.to(this.socketIOClientId).emit('end')
}
handleChainEnd(outputs: ChainValues): void | Promise<void> {
if (this.returnSourceDocuments) {
this.socketIO.to(this.socketIOClientId).emit('sourceDocuments', outputs?.sourceDocuments)
}
}
}

View File

@ -21,6 +21,7 @@ export interface IChatMessage {
content: string
chatflowid: string
createdDate: Date
sourceDocuments: string
}
export interface IComponentNodes {

View File

@ -17,6 +17,9 @@ export class ChatMessage implements IChatMessage {
@Column()
content: string
@Column({ nullable: true })
sourceDocuments: string
@CreateDateColumn()
createdDate: Date
}

View File

@ -90,7 +90,7 @@ export class App {
const basicAuthMiddleware = basicAuth({
users: { [username]: password }
})
const whitelistURLs = ['/api/v1/prediction/', '/api/v1/node-icon/']
const whitelistURLs = ['/api/v1/prediction/', '/api/v1/node-icon/', '/api/v1/chatflows-streaming']
this.app.use((req, res, next) => {
if (req.url.includes('/api/v1/')) {
whitelistURLs.some((url) => req.url.includes(url)) ? next() : basicAuthMiddleware(req, res, next)

View File

@ -27,9 +27,10 @@ import PerfectScrollbar from 'react-perfect-scrollbar'
import MainCard from 'ui-component/cards/MainCard'
import Transitions from 'ui-component/extended/Transitions'
import { BackdropLoader } from 'ui-component/loading/BackdropLoader'
import AboutDialog from 'ui-component/dialog/AboutDialog'
// assets
import { IconLogout, IconSettings, IconFileExport, IconFileDownload } from '@tabler/icons'
import { IconLogout, IconSettings, IconFileExport, IconFileDownload, IconInfoCircle } from '@tabler/icons'
// API
import databaseApi from 'api/database'
@ -49,6 +50,7 @@ const ProfileSection = ({ username, handleLogout }) => {
const [open, setOpen] = useState(false)
const [loading, setLoading] = useState(false)
const [aboutDialogOpen, setAboutDialogOpen] = useState(false)
const anchorRef = useRef(null)
const uploadRef = useRef(null)
@ -215,6 +217,18 @@ const ProfileSection = ({ username, handleLogout }) => {
</ListItemIcon>
<ListItemText primary={<Typography variant='body2'>Export Database</Typography>} />
</ListItemButton>
<ListItemButton
sx={{ borderRadius: `${customization.borderRadius}px` }}
onClick={() => {
setOpen(false)
setAboutDialogOpen(true)
}}
>
<ListItemIcon>
<IconInfoCircle stroke={1.5} size='1.3rem' />
</ListItemIcon>
<ListItemText primary={<Typography variant='body2'>About Flowise</Typography>} />
</ListItemButton>
{localStorage.getItem('username') && localStorage.getItem('password') && (
<ListItemButton
sx={{ borderRadius: `${customization.borderRadius}px` }}
@ -237,6 +251,7 @@ const ProfileSection = ({ username, handleLogout }) => {
</Popper>
<input ref={uploadRef} type='file' hidden accept='.json' onChange={(e) => handleFileUpload(e)} />
<BackdropLoader open={loading} />
<AboutDialog show={aboutDialogOpen} onCancel={() => setAboutDialogOpen(false)} />
</>
)
}

View File

@ -0,0 +1,85 @@
import { createPortal } from 'react-dom'
import { useState, useEffect } from 'react'
import PropTypes from 'prop-types'
import { Dialog, DialogContent, DialogTitle, TableContainer, Table, TableHead, TableRow, TableCell, TableBody, Paper } from '@mui/material'
import moment from 'moment'
import axios from 'axios'
const fetchLatestVer = async ({ api }) => {
let apiReturn = await axios
.get(api)
.then(async function (response) {
return response.data
})
.catch(function (error) {
console.error(error)
})
return apiReturn
}
const AboutDialog = ({ show, onCancel }) => {
const portalElement = document.getElementById('portal')
const [data, setData] = useState({})
useEffect(() => {
if (show) {
const fetchData = async (api) => {
let response = await fetchLatestVer({ api })
setData(response)
}
fetchData('https://api.github.com/repos/FlowiseAI/Flowise/releases/latest')
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [show])
const component = show ? (
<Dialog
onClose={onCancel}
open={show}
fullWidth
maxWidth='sm'
aria-labelledby='alert-dialog-title'
aria-describedby='alert-dialog-description'
>
<DialogTitle sx={{ fontSize: '1rem' }} id='alert-dialog-title'>
Flowise Version
</DialogTitle>
<DialogContent>
{data && (
<TableContainer component={Paper}>
<Table aria-label='simple table'>
<TableHead>
<TableRow>
<TableCell>Latest Version</TableCell>
<TableCell>Published At</TableCell>
</TableRow>
</TableHead>
<TableBody>
<TableRow sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
<TableCell component='th' scope='row'>
<a target='_blank' rel='noreferrer' href={data.html_url}>
{data.name}
</a>
</TableCell>
<TableCell>{moment(data.published_at).fromNow()}</TableCell>
</TableRow>
</TableBody>
</Table>
</TableContainer>
)}
</DialogContent>
</Dialog>
) : null
return createPortal(component, portalElement)
}
AboutDialog.propTypes = {
show: PropTypes.bool,
onCancel: PropTypes.func
}
export default AboutDialog

View File

@ -0,0 +1,57 @@
import { createPortal } from 'react-dom'
import { useState, useEffect } from 'react'
import { useSelector } from 'react-redux'
import PropTypes from 'prop-types'
import { Dialog, DialogContent, DialogTitle } from '@mui/material'
import ReactJson from 'react-json-view'
const SourceDocDialog = ({ show, dialogProps, onCancel }) => {
const portalElement = document.getElementById('portal')
const customization = useSelector((state) => state.customization)
const [data, setData] = useState({})
useEffect(() => {
if (dialogProps.data) setData(dialogProps.data)
return () => {
setData({})
}
}, [dialogProps])
const component = show ? (
<Dialog
onClose={onCancel}
open={show}
fullWidth
maxWidth='sm'
aria-labelledby='alert-dialog-title'
aria-describedby='alert-dialog-description'
>
<DialogTitle sx={{ fontSize: '1rem' }} id='alert-dialog-title'>
Source Document
</DialogTitle>
<DialogContent>
<ReactJson
theme={customization.isDarkMode ? 'ocean' : 'rjv-default'}
style={{ padding: 10, borderRadius: 10 }}
src={data}
name={null}
quotesOnKeys={false}
enableClipboard={false}
displayDataTypes={false}
/>
</DialogContent>
</Dialog>
) : null
return createPortal(component, portalElement)
}
SourceDocDialog.propTypes = {
show: PropTypes.bool,
dialogProps: PropTypes.object,
onCancel: PropTypes.func
}
export default SourceDocDialog

View File

@ -7,13 +7,14 @@ import rehypeMathjax from 'rehype-mathjax'
import remarkGfm from 'remark-gfm'
import remarkMath from 'remark-math'
import { CircularProgress, OutlinedInput, Divider, InputAdornment, IconButton, Box } from '@mui/material'
import { CircularProgress, OutlinedInput, Divider, InputAdornment, IconButton, Box, Chip } from '@mui/material'
import { useTheme } from '@mui/material/styles'
import { IconSend } from '@tabler/icons'
// project import
import { CodeBlock } from 'ui-component/markdown/CodeBlock'
import { MemoizedReactMarkdown } from 'ui-component/markdown/MemoizedReactMarkdown'
import SourceDocDialog from 'ui-component/dialog/SourceDocDialog'
import './ChatMessage.css'
// api
@ -43,11 +44,18 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
])
const [socketIOClientId, setSocketIOClientId] = useState('')
const [isChatFlowAvailableToStream, setIsChatFlowAvailableToStream] = useState(false)
const [sourceDialogOpen, setSourceDialogOpen] = useState(false)
const [sourceDialogProps, setSourceDialogProps] = useState({})
const inputRef = useRef(null)
const getChatmessageApi = useApi(chatmessageApi.getChatmessageFromChatflow)
const getIsChatflowStreamingApi = useApi(chatflowsApi.getIsChatflowStreaming)
const onSourceDialogClick = (data) => {
setSourceDialogProps({ data })
setSourceDialogOpen(true)
}
const scrollToBottom = () => {
if (ps.current) {
ps.current.scrollTo({ top: maxScroll })
@ -56,13 +64,14 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
const onChange = useCallback((e) => setUserInput(e.target.value), [setUserInput])
const addChatMessage = async (message, type) => {
const addChatMessage = async (message, type, sourceDocuments) => {
try {
const newChatMessageBody = {
role: type,
content: message,
chatflowid: chatflowid
}
if (sourceDocuments) newChatMessageBody.sourceDocuments = JSON.stringify(sourceDocuments)
await chatmessageApi.createNewChatmessage(chatflowid, newChatMessageBody)
} catch (error) {
console.error(error)
@ -78,6 +87,15 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
})
}
const updateLastMessageSourceDocuments = (sourceDocuments) => {
setMessages((prevMessages) => {
let allMessages = [...cloneDeep(prevMessages)]
if (allMessages[allMessages.length - 1].type === 'userMessage') return allMessages
allMessages[allMessages.length - 1].sourceDocuments = sourceDocuments
return allMessages
})
}
// Handle errors
const handleError = (message = 'Oops! There seems to be an error. Please try again.') => {
message = message.replace(`Unable to parse JSON response from chat agent.\n\n`, '')
@ -114,8 +132,20 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
if (response.data) {
const data = response.data
if (!isChatFlowAvailableToStream) setMessages((prevMessages) => [...prevMessages, { message: data, type: 'apiMessage' }])
addChatMessage(data, 'apiMessage')
if (typeof data === 'object' && data.text && data.sourceDocuments) {
if (!isChatFlowAvailableToStream) {
setMessages((prevMessages) => [
...prevMessages,
{ message: data.text, sourceDocuments: data.sourceDocuments, type: 'apiMessage' }
])
}
addChatMessage(data.text, 'apiMessage', data.sourceDocuments)
} else {
if (!isChatFlowAvailableToStream) {
setMessages((prevMessages) => [...prevMessages, { message: data, type: 'apiMessage' }])
}
addChatMessage(data, 'apiMessage')
}
setLoading(false)
setUserInput('')
setTimeout(() => {
@ -146,10 +176,12 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
if (getChatmessageApi.data) {
const loadedMessages = []
for (const message of getChatmessageApi.data) {
loadedMessages.push({
const obj = {
message: message.content,
type: message.role
})
}
if (message.sourceDocuments) obj.sourceDocuments = JSON.parse(message.sourceDocuments)
loadedMessages.push(obj)
}
setMessages((prevMessages) => [...prevMessages, ...loadedMessages])
}
@ -196,6 +228,8 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
setMessages((prevMessages) => [...prevMessages, { message: '', type: 'apiMessage' }])
})
socket.on('sourceDocuments', updateLastMessageSourceDocuments)
socket.on('token', updateLastMessage)
}
@ -225,69 +259,91 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
messages.map((message, index) => {
return (
// The latest message sent by the user will be animated while waiting for a response
<Box
sx={{
background: message.type === 'apiMessage' ? theme.palette.asyncSelect.main : ''
}}
key={index}
style={{ display: 'flex' }}
className={
message.type === 'userMessage' && loading && index === messages.length - 1
? customization.isDarkMode
? 'usermessagewaiting-dark'
: 'usermessagewaiting-light'
: message.type === 'usermessagewaiting'
? 'apimessage'
: 'usermessage'
}
>
{/* Display the correct icon depending on the message type */}
{message.type === 'apiMessage' ? (
<img
src='https://raw.githubusercontent.com/zahidkhawaja/langchain-chat-nextjs/main/public/parroticon.png'
alt='AI'
width='30'
height='30'
className='boticon'
/>
) : (
<img
src='https://raw.githubusercontent.com/zahidkhawaja/langchain-chat-nextjs/main/public/usericon.png'
alt='Me'
width='30'
height='30'
className='usericon'
/>
)}
<div className='markdownanswer'>
{/* Messages are being rendered in Markdown format */}
<MemoizedReactMarkdown
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[rehypeMathjax]}
components={{
code({ inline, className, children, ...props }) {
const match = /language-(\w+)/.exec(className || '')
return !inline ? (
<CodeBlock
key={Math.random()}
chatflowid={chatflowid}
isDialog={isDialog}
language={(match && match[1]) || ''}
value={String(children).replace(/\n$/, '')}
{...props}
/>
) : (
<code className={className} {...props}>
{children}
</code>
)
}
}}
>
{message.message}
</MemoizedReactMarkdown>
</div>
</Box>
<>
<Box
sx={{
background: message.type === 'apiMessage' ? theme.palette.asyncSelect.main : ''
}}
key={index}
style={{ display: 'flex' }}
className={
message.type === 'userMessage' && loading && index === messages.length - 1
? customization.isDarkMode
? 'usermessagewaiting-dark'
: 'usermessagewaiting-light'
: message.type === 'usermessagewaiting'
? 'apimessage'
: 'usermessage'
}
>
{/* Display the correct icon depending on the message type */}
{message.type === 'apiMessage' ? (
<img
src='https://raw.githubusercontent.com/zahidkhawaja/langchain-chat-nextjs/main/public/parroticon.png'
alt='AI'
width='30'
height='30'
className='boticon'
/>
) : (
<img
src='https://raw.githubusercontent.com/zahidkhawaja/langchain-chat-nextjs/main/public/usericon.png'
alt='Me'
width='30'
height='30'
className='usericon'
/>
)}
<div style={{ display: 'flex', flexDirection: 'column', width: '100%' }}>
<div className='markdownanswer'>
{/* Messages are being rendered in Markdown format */}
<MemoizedReactMarkdown
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[rehypeMathjax]}
components={{
code({ inline, className, children, ...props }) {
const match = /language-(\w+)/.exec(className || '')
return !inline ? (
<CodeBlock
key={Math.random()}
chatflowid={chatflowid}
isDialog={isDialog}
language={(match && match[1]) || ''}
value={String(children).replace(/\n$/, '')}
{...props}
/>
) : (
<code className={className} {...props}>
{children}
</code>
)
}
}}
>
{message.message}
</MemoizedReactMarkdown>
</div>
{message.sourceDocuments && (
<div style={{ display: 'block', flexDirection: 'row', width: '100%' }}>
{message.sourceDocuments.map((source, index) => {
return (
<Chip
size='small'
key={index}
label={`${source.pageContent.substring(0, 15)}...`}
component='a'
sx={{ mr: 1, mb: 1 }}
variant='outlined'
clickable
onClick={() => onSourceDialogClick(source)}
/>
)
})}
</div>
)}
</div>
</Box>
</>
)
})}
</div>
@ -328,6 +384,7 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
</form>
</div>
</div>
<SourceDocDialog show={sourceDialogOpen} dialogProps={sourceDialogProps} onCancel={() => setSourceDialogOpen(false)} />
</>
)
}