add return source documents functioanality
This commit is contained in:
parent
ab00214ec2
commit
b071790a5a
|
|
@ -2,7 +2,9 @@ import { BaseLanguageModel } from 'langchain/base_language'
|
||||||
import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
|
import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
|
||||||
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
|
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
|
||||||
import { ConversationalRetrievalQAChain } from 'langchain/chains'
|
import { ConversationalRetrievalQAChain } from 'langchain/chains'
|
||||||
import { BaseRetriever } from 'langchain/schema'
|
import { AIChatMessage, BaseRetriever, HumanChatMessage } from 'langchain/schema'
|
||||||
|
import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory'
|
||||||
|
import { PromptTemplate } from 'langchain/prompts'
|
||||||
|
|
||||||
const default_qa_template = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
const default_qa_template = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||||
|
|
||||||
|
|
@ -47,6 +49,12 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||||
name: 'vectorStoreRetriever',
|
name: 'vectorStoreRetriever',
|
||||||
type: 'BaseRetriever'
|
type: 'BaseRetriever'
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
label: 'Return Source Documents',
|
||||||
|
name: 'returnSourceDocuments',
|
||||||
|
type: 'boolean',
|
||||||
|
optional: true
|
||||||
|
},
|
||||||
{
|
{
|
||||||
label: 'System Message',
|
label: 'System Message',
|
||||||
name: 'systemMessagePrompt',
|
name: 'systemMessagePrompt',
|
||||||
|
|
@ -56,6 +64,31 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||||
optional: true,
|
optional: true,
|
||||||
placeholder:
|
placeholder:
|
||||||
'I want you to act as a document that I am having a conversation with. Your name is "AI Assistant". You will provide me with answers from the given info. If the answer is not included, say exactly "Hmm, I am not sure." and stop after that. Refuse to answer any question not about the info. Never break character.'
|
'I want you to act as a document that I am having a conversation with. Your name is "AI Assistant". You will provide me with answers from the given info. If the answer is not included, say exactly "Hmm, I am not sure." and stop after that. Refuse to answer any question not about the info. Never break character.'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'Chain Option',
|
||||||
|
name: 'chainOption',
|
||||||
|
type: 'options',
|
||||||
|
options: [
|
||||||
|
{
|
||||||
|
label: 'MapReduceDocumentsChain',
|
||||||
|
name: 'map_reduce',
|
||||||
|
description:
|
||||||
|
'Suitable for QA tasks over larger documents and can run the preprocessing step in parallel, reducing the running time'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'RefineDocumentsChain',
|
||||||
|
name: 'refine',
|
||||||
|
description: 'Suitable for QA tasks over a large number of documents.'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'StuffDocumentsChain',
|
||||||
|
name: 'stuff',
|
||||||
|
description: 'Suitable for QA tasks over a small number of documents.'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
additionalParams: true,
|
||||||
|
optional: true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
@ -64,44 +97,64 @@ class ConversationalRetrievalQAChain_Chains implements INode {
|
||||||
const model = nodeData.inputs?.model as BaseLanguageModel
|
const model = nodeData.inputs?.model as BaseLanguageModel
|
||||||
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever
|
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever
|
||||||
const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
|
const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
|
||||||
|
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
|
||||||
|
const chainOption = nodeData.inputs?.chainOption as string
|
||||||
|
|
||||||
const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, {
|
const obj: any = {
|
||||||
verbose: process.env.DEBUG === 'true' ? true : false,
|
verbose: process.env.DEBUG === 'true' ? true : false,
|
||||||
qaTemplate: systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template
|
qaChainOptions: {
|
||||||
|
type: 'stuff',
|
||||||
|
prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
|
||||||
|
},
|
||||||
|
memory: new BufferMemory({
|
||||||
|
memoryKey: 'chat_history',
|
||||||
|
inputKey: 'question',
|
||||||
|
outputKey: 'text',
|
||||||
|
returnMessages: true
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
|
||||||
|
if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }
|
||||||
|
|
||||||
|
const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
|
||||||
return chain
|
return chain
|
||||||
}
|
}
|
||||||
|
|
||||||
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
|
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
|
||||||
const chain = nodeData.instance as ConversationalRetrievalQAChain
|
const chain = nodeData.instance as ConversationalRetrievalQAChain
|
||||||
|
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
|
||||||
let model = nodeData.inputs?.model
|
let model = nodeData.inputs?.model
|
||||||
|
|
||||||
// Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
|
// Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
|
||||||
model.streaming = false
|
model.streaming = false
|
||||||
chain.questionGeneratorChain.llm = model
|
chain.questionGeneratorChain.llm = model
|
||||||
|
|
||||||
let chatHistory = ''
|
const obj = { question: input }
|
||||||
|
|
||||||
if (options && options.chatHistory) {
|
if (chain.memory && options && options.chatHistory) {
|
||||||
|
const chatHistory = []
|
||||||
const histories: IMessage[] = options.chatHistory
|
const histories: IMessage[] = options.chatHistory
|
||||||
chatHistory = histories
|
const memory = chain.memory as BaseChatMemory
|
||||||
.map((item) => {
|
|
||||||
return item.message
|
|
||||||
})
|
|
||||||
.join('')
|
|
||||||
}
|
|
||||||
|
|
||||||
const obj = {
|
for (const message of histories) {
|
||||||
question: input,
|
if (message.type === 'apiMessage') {
|
||||||
chat_history: chatHistory ? chatHistory : []
|
chatHistory.push(new AIChatMessage(message.message))
|
||||||
|
} else if (message.type === 'userMessage') {
|
||||||
|
chatHistory.push(new HumanChatMessage(message.message))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
memory.chatHistory = new ChatMessageHistory(chatHistory)
|
||||||
|
chain.memory = memory
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.socketIO && options.socketIOClientId) {
|
if (options.socketIO && options.socketIOClientId) {
|
||||||
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
|
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, undefined, returnSourceDocuments)
|
||||||
const res = await chain.call(obj, [handler])
|
const res = await chain.call(obj, [handler])
|
||||||
|
if (res.text && res.sourceDocuments) return res
|
||||||
return res?.text
|
return res?.text
|
||||||
} else {
|
} else {
|
||||||
const res = await chain.call(obj)
|
const res = await chain.call(obj)
|
||||||
|
if (res.text && res.sourceDocuments) return res
|
||||||
return res?.text
|
return res?.text
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@
|
||||||
"faiss-node": "^0.2.1",
|
"faiss-node": "^0.2.1",
|
||||||
"form-data": "^4.0.0",
|
"form-data": "^4.0.0",
|
||||||
"graphql": "^16.6.0",
|
"graphql": "^16.6.0",
|
||||||
"langchain": "^0.0.84",
|
"langchain": "^0.0.91",
|
||||||
"linkifyjs": "^4.1.1",
|
"linkifyjs": "^4.1.1",
|
||||||
"mammoth": "^1.5.1",
|
"mammoth": "^1.5.1",
|
||||||
"moment": "^2.29.3",
|
"moment": "^2.29.3",
|
||||||
|
|
|
||||||
|
|
@ -75,7 +75,7 @@ export interface INode extends INodeProperties {
|
||||||
inputs?: INodeParams[]
|
inputs?: INodeParams[]
|
||||||
output?: INodeOutputsValue[]
|
output?: INodeOutputsValue[]
|
||||||
init?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<any>
|
init?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<any>
|
||||||
run?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<string>
|
run?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<string | ICommonObject>
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface INodeData extends INodeProperties {
|
export interface INodeData extends INodeProperties {
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import * as fs from 'fs'
|
||||||
import * as path from 'path'
|
import * as path from 'path'
|
||||||
import { BaseCallbackHandler } from 'langchain/callbacks'
|
import { BaseCallbackHandler } from 'langchain/callbacks'
|
||||||
import { Server } from 'socket.io'
|
import { Server } from 'socket.io'
|
||||||
|
import { ChainValues } from 'langchain/dist/schema'
|
||||||
|
|
||||||
export const numberOrExpressionRegex = '^(\\d+\\.?\\d*|{{.*}})$' //return true if string consists only numbers OR expression {{}}
|
export const numberOrExpressionRegex = '^(\\d+\\.?\\d*|{{.*}})$' //return true if string consists only numbers OR expression {{}}
|
||||||
export const notEmptyRegex = '(.|\\s)*\\S(.|\\s)*' //return true if string is not empty or blank
|
export const notEmptyRegex = '(.|\\s)*\\S(.|\\s)*' //return true if string is not empty or blank
|
||||||
|
|
@ -208,12 +209,14 @@ export class CustomChainHandler extends BaseCallbackHandler {
|
||||||
socketIO: Server
|
socketIO: Server
|
||||||
socketIOClientId = ''
|
socketIOClientId = ''
|
||||||
skipK = 0 // Skip streaming for first K numbers of handleLLMStart
|
skipK = 0 // Skip streaming for first K numbers of handleLLMStart
|
||||||
|
returnSourceDocuments = false
|
||||||
|
|
||||||
constructor(socketIO: Server, socketIOClientId: string, skipK?: number) {
|
constructor(socketIO: Server, socketIOClientId: string, skipK?: number, returnSourceDocuments?: boolean) {
|
||||||
super()
|
super()
|
||||||
this.socketIO = socketIO
|
this.socketIO = socketIO
|
||||||
this.socketIOClientId = socketIOClientId
|
this.socketIOClientId = socketIOClientId
|
||||||
this.skipK = skipK ?? this.skipK
|
this.skipK = skipK ?? this.skipK
|
||||||
|
this.returnSourceDocuments = returnSourceDocuments ?? this.returnSourceDocuments
|
||||||
}
|
}
|
||||||
|
|
||||||
handleLLMStart() {
|
handleLLMStart() {
|
||||||
|
|
@ -233,4 +236,10 @@ export class CustomChainHandler extends BaseCallbackHandler {
|
||||||
handleLLMEnd() {
|
handleLLMEnd() {
|
||||||
this.socketIO.to(this.socketIOClientId).emit('end')
|
this.socketIO.to(this.socketIOClientId).emit('end')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handleChainEnd(outputs: ChainValues): void | Promise<void> {
|
||||||
|
if (this.returnSourceDocuments) {
|
||||||
|
this.socketIO.to(this.socketIOClientId).emit('sourceDocuments', outputs?.sourceDocuments)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ export interface IChatMessage {
|
||||||
content: string
|
content: string
|
||||||
chatflowid: string
|
chatflowid: string
|
||||||
createdDate: Date
|
createdDate: Date
|
||||||
|
sourceDocuments: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface IComponentNodes {
|
export interface IComponentNodes {
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,9 @@ export class ChatMessage implements IChatMessage {
|
||||||
@Column()
|
@Column()
|
||||||
content: string
|
content: string
|
||||||
|
|
||||||
|
@Column({ nullable: true })
|
||||||
|
sourceDocuments: string
|
||||||
|
|
||||||
@CreateDateColumn()
|
@CreateDateColumn()
|
||||||
createdDate: Date
|
createdDate: Date
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -90,7 +90,7 @@ export class App {
|
||||||
const basicAuthMiddleware = basicAuth({
|
const basicAuthMiddleware = basicAuth({
|
||||||
users: { [username]: password }
|
users: { [username]: password }
|
||||||
})
|
})
|
||||||
const whitelistURLs = ['/api/v1/prediction/', '/api/v1/node-icon/']
|
const whitelistURLs = ['/api/v1/prediction/', '/api/v1/node-icon/', '/api/v1/chatflows-streaming']
|
||||||
this.app.use((req, res, next) => {
|
this.app.use((req, res, next) => {
|
||||||
if (req.url.includes('/api/v1/')) {
|
if (req.url.includes('/api/v1/')) {
|
||||||
whitelistURLs.some((url) => req.url.includes(url)) ? next() : basicAuthMiddleware(req, res, next)
|
whitelistURLs.some((url) => req.url.includes(url)) ? next() : basicAuthMiddleware(req, res, next)
|
||||||
|
|
|
||||||
|
|
@ -27,9 +27,10 @@ import PerfectScrollbar from 'react-perfect-scrollbar'
|
||||||
import MainCard from 'ui-component/cards/MainCard'
|
import MainCard from 'ui-component/cards/MainCard'
|
||||||
import Transitions from 'ui-component/extended/Transitions'
|
import Transitions from 'ui-component/extended/Transitions'
|
||||||
import { BackdropLoader } from 'ui-component/loading/BackdropLoader'
|
import { BackdropLoader } from 'ui-component/loading/BackdropLoader'
|
||||||
|
import AboutDialog from 'ui-component/dialog/AboutDialog'
|
||||||
|
|
||||||
// assets
|
// assets
|
||||||
import { IconLogout, IconSettings, IconFileExport, IconFileDownload } from '@tabler/icons'
|
import { IconLogout, IconSettings, IconFileExport, IconFileDownload, IconInfoCircle } from '@tabler/icons'
|
||||||
|
|
||||||
// API
|
// API
|
||||||
import databaseApi from 'api/database'
|
import databaseApi from 'api/database'
|
||||||
|
|
@ -49,6 +50,7 @@ const ProfileSection = ({ username, handleLogout }) => {
|
||||||
|
|
||||||
const [open, setOpen] = useState(false)
|
const [open, setOpen] = useState(false)
|
||||||
const [loading, setLoading] = useState(false)
|
const [loading, setLoading] = useState(false)
|
||||||
|
const [aboutDialogOpen, setAboutDialogOpen] = useState(false)
|
||||||
|
|
||||||
const anchorRef = useRef(null)
|
const anchorRef = useRef(null)
|
||||||
const uploadRef = useRef(null)
|
const uploadRef = useRef(null)
|
||||||
|
|
@ -215,6 +217,18 @@ const ProfileSection = ({ username, handleLogout }) => {
|
||||||
</ListItemIcon>
|
</ListItemIcon>
|
||||||
<ListItemText primary={<Typography variant='body2'>Export Database</Typography>} />
|
<ListItemText primary={<Typography variant='body2'>Export Database</Typography>} />
|
||||||
</ListItemButton>
|
</ListItemButton>
|
||||||
|
<ListItemButton
|
||||||
|
sx={{ borderRadius: `${customization.borderRadius}px` }}
|
||||||
|
onClick={() => {
|
||||||
|
setOpen(false)
|
||||||
|
setAboutDialogOpen(true)
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<ListItemIcon>
|
||||||
|
<IconInfoCircle stroke={1.5} size='1.3rem' />
|
||||||
|
</ListItemIcon>
|
||||||
|
<ListItemText primary={<Typography variant='body2'>About Flowise</Typography>} />
|
||||||
|
</ListItemButton>
|
||||||
{localStorage.getItem('username') && localStorage.getItem('password') && (
|
{localStorage.getItem('username') && localStorage.getItem('password') && (
|
||||||
<ListItemButton
|
<ListItemButton
|
||||||
sx={{ borderRadius: `${customization.borderRadius}px` }}
|
sx={{ borderRadius: `${customization.borderRadius}px` }}
|
||||||
|
|
@ -237,6 +251,7 @@ const ProfileSection = ({ username, handleLogout }) => {
|
||||||
</Popper>
|
</Popper>
|
||||||
<input ref={uploadRef} type='file' hidden accept='.json' onChange={(e) => handleFileUpload(e)} />
|
<input ref={uploadRef} type='file' hidden accept='.json' onChange={(e) => handleFileUpload(e)} />
|
||||||
<BackdropLoader open={loading} />
|
<BackdropLoader open={loading} />
|
||||||
|
<AboutDialog show={aboutDialogOpen} onCancel={() => setAboutDialogOpen(false)} />
|
||||||
</>
|
</>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,85 @@
|
||||||
|
import { createPortal } from 'react-dom'
|
||||||
|
import { useState, useEffect } from 'react'
|
||||||
|
import PropTypes from 'prop-types'
|
||||||
|
import { Dialog, DialogContent, DialogTitle, TableContainer, Table, TableHead, TableRow, TableCell, TableBody, Paper } from '@mui/material'
|
||||||
|
import moment from 'moment'
|
||||||
|
import axios from 'axios'
|
||||||
|
|
||||||
|
const fetchLatestVer = async ({ api }) => {
|
||||||
|
let apiReturn = await axios
|
||||||
|
.get(api)
|
||||||
|
.then(async function (response) {
|
||||||
|
return response.data
|
||||||
|
})
|
||||||
|
.catch(function (error) {
|
||||||
|
console.error(error)
|
||||||
|
})
|
||||||
|
return apiReturn
|
||||||
|
}
|
||||||
|
|
||||||
|
const AboutDialog = ({ show, onCancel }) => {
|
||||||
|
const portalElement = document.getElementById('portal')
|
||||||
|
|
||||||
|
const [data, setData] = useState({})
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (show) {
|
||||||
|
const fetchData = async (api) => {
|
||||||
|
let response = await fetchLatestVer({ api })
|
||||||
|
setData(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
fetchData('https://api.github.com/repos/FlowiseAI/Flowise/releases/latest')
|
||||||
|
}
|
||||||
|
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [show])
|
||||||
|
|
||||||
|
const component = show ? (
|
||||||
|
<Dialog
|
||||||
|
onClose={onCancel}
|
||||||
|
open={show}
|
||||||
|
fullWidth
|
||||||
|
maxWidth='sm'
|
||||||
|
aria-labelledby='alert-dialog-title'
|
||||||
|
aria-describedby='alert-dialog-description'
|
||||||
|
>
|
||||||
|
<DialogTitle sx={{ fontSize: '1rem' }} id='alert-dialog-title'>
|
||||||
|
Flowise Version
|
||||||
|
</DialogTitle>
|
||||||
|
<DialogContent>
|
||||||
|
{data && (
|
||||||
|
<TableContainer component={Paper}>
|
||||||
|
<Table aria-label='simple table'>
|
||||||
|
<TableHead>
|
||||||
|
<TableRow>
|
||||||
|
<TableCell>Latest Version</TableCell>
|
||||||
|
<TableCell>Published At</TableCell>
|
||||||
|
</TableRow>
|
||||||
|
</TableHead>
|
||||||
|
<TableBody>
|
||||||
|
<TableRow sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
||||||
|
<TableCell component='th' scope='row'>
|
||||||
|
<a target='_blank' rel='noreferrer' href={data.html_url}>
|
||||||
|
{data.name}
|
||||||
|
</a>
|
||||||
|
</TableCell>
|
||||||
|
<TableCell>{moment(data.published_at).fromNow()}</TableCell>
|
||||||
|
</TableRow>
|
||||||
|
</TableBody>
|
||||||
|
</Table>
|
||||||
|
</TableContainer>
|
||||||
|
)}
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
) : null
|
||||||
|
|
||||||
|
return createPortal(component, portalElement)
|
||||||
|
}
|
||||||
|
|
||||||
|
AboutDialog.propTypes = {
|
||||||
|
show: PropTypes.bool,
|
||||||
|
onCancel: PropTypes.func
|
||||||
|
}
|
||||||
|
|
||||||
|
export default AboutDialog
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
import { createPortal } from 'react-dom'
|
||||||
|
import { useState, useEffect } from 'react'
|
||||||
|
import { useSelector } from 'react-redux'
|
||||||
|
import PropTypes from 'prop-types'
|
||||||
|
import { Dialog, DialogContent, DialogTitle } from '@mui/material'
|
||||||
|
import ReactJson from 'react-json-view'
|
||||||
|
|
||||||
|
const SourceDocDialog = ({ show, dialogProps, onCancel }) => {
|
||||||
|
const portalElement = document.getElementById('portal')
|
||||||
|
const customization = useSelector((state) => state.customization)
|
||||||
|
|
||||||
|
const [data, setData] = useState({})
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (dialogProps.data) setData(dialogProps.data)
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
setData({})
|
||||||
|
}
|
||||||
|
}, [dialogProps])
|
||||||
|
|
||||||
|
const component = show ? (
|
||||||
|
<Dialog
|
||||||
|
onClose={onCancel}
|
||||||
|
open={show}
|
||||||
|
fullWidth
|
||||||
|
maxWidth='sm'
|
||||||
|
aria-labelledby='alert-dialog-title'
|
||||||
|
aria-describedby='alert-dialog-description'
|
||||||
|
>
|
||||||
|
<DialogTitle sx={{ fontSize: '1rem' }} id='alert-dialog-title'>
|
||||||
|
Source Document
|
||||||
|
</DialogTitle>
|
||||||
|
<DialogContent>
|
||||||
|
<ReactJson
|
||||||
|
theme={customization.isDarkMode ? 'ocean' : 'rjv-default'}
|
||||||
|
style={{ padding: 10, borderRadius: 10 }}
|
||||||
|
src={data}
|
||||||
|
name={null}
|
||||||
|
quotesOnKeys={false}
|
||||||
|
enableClipboard={false}
|
||||||
|
displayDataTypes={false}
|
||||||
|
/>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
) : null
|
||||||
|
|
||||||
|
return createPortal(component, portalElement)
|
||||||
|
}
|
||||||
|
|
||||||
|
SourceDocDialog.propTypes = {
|
||||||
|
show: PropTypes.bool,
|
||||||
|
dialogProps: PropTypes.object,
|
||||||
|
onCancel: PropTypes.func
|
||||||
|
}
|
||||||
|
|
||||||
|
export default SourceDocDialog
|
||||||
|
|
@ -7,13 +7,14 @@ import rehypeMathjax from 'rehype-mathjax'
|
||||||
import remarkGfm from 'remark-gfm'
|
import remarkGfm from 'remark-gfm'
|
||||||
import remarkMath from 'remark-math'
|
import remarkMath from 'remark-math'
|
||||||
|
|
||||||
import { CircularProgress, OutlinedInput, Divider, InputAdornment, IconButton, Box } from '@mui/material'
|
import { CircularProgress, OutlinedInput, Divider, InputAdornment, IconButton, Box, Chip } from '@mui/material'
|
||||||
import { useTheme } from '@mui/material/styles'
|
import { useTheme } from '@mui/material/styles'
|
||||||
import { IconSend } from '@tabler/icons'
|
import { IconSend } from '@tabler/icons'
|
||||||
|
|
||||||
// project import
|
// project import
|
||||||
import { CodeBlock } from 'ui-component/markdown/CodeBlock'
|
import { CodeBlock } from 'ui-component/markdown/CodeBlock'
|
||||||
import { MemoizedReactMarkdown } from 'ui-component/markdown/MemoizedReactMarkdown'
|
import { MemoizedReactMarkdown } from 'ui-component/markdown/MemoizedReactMarkdown'
|
||||||
|
import SourceDocDialog from 'ui-component/dialog/SourceDocDialog'
|
||||||
import './ChatMessage.css'
|
import './ChatMessage.css'
|
||||||
|
|
||||||
// api
|
// api
|
||||||
|
|
@ -43,11 +44,18 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
|
||||||
])
|
])
|
||||||
const [socketIOClientId, setSocketIOClientId] = useState('')
|
const [socketIOClientId, setSocketIOClientId] = useState('')
|
||||||
const [isChatFlowAvailableToStream, setIsChatFlowAvailableToStream] = useState(false)
|
const [isChatFlowAvailableToStream, setIsChatFlowAvailableToStream] = useState(false)
|
||||||
|
const [sourceDialogOpen, setSourceDialogOpen] = useState(false)
|
||||||
|
const [sourceDialogProps, setSourceDialogProps] = useState({})
|
||||||
|
|
||||||
const inputRef = useRef(null)
|
const inputRef = useRef(null)
|
||||||
const getChatmessageApi = useApi(chatmessageApi.getChatmessageFromChatflow)
|
const getChatmessageApi = useApi(chatmessageApi.getChatmessageFromChatflow)
|
||||||
const getIsChatflowStreamingApi = useApi(chatflowsApi.getIsChatflowStreaming)
|
const getIsChatflowStreamingApi = useApi(chatflowsApi.getIsChatflowStreaming)
|
||||||
|
|
||||||
|
const onSourceDialogClick = (data) => {
|
||||||
|
setSourceDialogProps({ data })
|
||||||
|
setSourceDialogOpen(true)
|
||||||
|
}
|
||||||
|
|
||||||
const scrollToBottom = () => {
|
const scrollToBottom = () => {
|
||||||
if (ps.current) {
|
if (ps.current) {
|
||||||
ps.current.scrollTo({ top: maxScroll })
|
ps.current.scrollTo({ top: maxScroll })
|
||||||
|
|
@ -56,13 +64,14 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
|
||||||
|
|
||||||
const onChange = useCallback((e) => setUserInput(e.target.value), [setUserInput])
|
const onChange = useCallback((e) => setUserInput(e.target.value), [setUserInput])
|
||||||
|
|
||||||
const addChatMessage = async (message, type) => {
|
const addChatMessage = async (message, type, sourceDocuments) => {
|
||||||
try {
|
try {
|
||||||
const newChatMessageBody = {
|
const newChatMessageBody = {
|
||||||
role: type,
|
role: type,
|
||||||
content: message,
|
content: message,
|
||||||
chatflowid: chatflowid
|
chatflowid: chatflowid
|
||||||
}
|
}
|
||||||
|
if (sourceDocuments) newChatMessageBody.sourceDocuments = JSON.stringify(sourceDocuments)
|
||||||
await chatmessageApi.createNewChatmessage(chatflowid, newChatMessageBody)
|
await chatmessageApi.createNewChatmessage(chatflowid, newChatMessageBody)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error)
|
console.error(error)
|
||||||
|
|
@ -78,6 +87,15 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const updateLastMessageSourceDocuments = (sourceDocuments) => {
|
||||||
|
setMessages((prevMessages) => {
|
||||||
|
let allMessages = [...cloneDeep(prevMessages)]
|
||||||
|
if (allMessages[allMessages.length - 1].type === 'userMessage') return allMessages
|
||||||
|
allMessages[allMessages.length - 1].sourceDocuments = sourceDocuments
|
||||||
|
return allMessages
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Handle errors
|
// Handle errors
|
||||||
const handleError = (message = 'Oops! There seems to be an error. Please try again.') => {
|
const handleError = (message = 'Oops! There seems to be an error. Please try again.') => {
|
||||||
message = message.replace(`Unable to parse JSON response from chat agent.\n\n`, '')
|
message = message.replace(`Unable to parse JSON response from chat agent.\n\n`, '')
|
||||||
|
|
@ -114,8 +132,20 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
|
||||||
|
|
||||||
if (response.data) {
|
if (response.data) {
|
||||||
const data = response.data
|
const data = response.data
|
||||||
if (!isChatFlowAvailableToStream) setMessages((prevMessages) => [...prevMessages, { message: data, type: 'apiMessage' }])
|
if (typeof data === 'object' && data.text && data.sourceDocuments) {
|
||||||
|
if (!isChatFlowAvailableToStream) {
|
||||||
|
setMessages((prevMessages) => [
|
||||||
|
...prevMessages,
|
||||||
|
{ message: data.text, sourceDocuments: data.sourceDocuments, type: 'apiMessage' }
|
||||||
|
])
|
||||||
|
}
|
||||||
|
addChatMessage(data.text, 'apiMessage', data.sourceDocuments)
|
||||||
|
} else {
|
||||||
|
if (!isChatFlowAvailableToStream) {
|
||||||
|
setMessages((prevMessages) => [...prevMessages, { message: data, type: 'apiMessage' }])
|
||||||
|
}
|
||||||
addChatMessage(data, 'apiMessage')
|
addChatMessage(data, 'apiMessage')
|
||||||
|
}
|
||||||
setLoading(false)
|
setLoading(false)
|
||||||
setUserInput('')
|
setUserInput('')
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
|
|
@ -146,10 +176,12 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
|
||||||
if (getChatmessageApi.data) {
|
if (getChatmessageApi.data) {
|
||||||
const loadedMessages = []
|
const loadedMessages = []
|
||||||
for (const message of getChatmessageApi.data) {
|
for (const message of getChatmessageApi.data) {
|
||||||
loadedMessages.push({
|
const obj = {
|
||||||
message: message.content,
|
message: message.content,
|
||||||
type: message.role
|
type: message.role
|
||||||
})
|
}
|
||||||
|
if (message.sourceDocuments) obj.sourceDocuments = JSON.parse(message.sourceDocuments)
|
||||||
|
loadedMessages.push(obj)
|
||||||
}
|
}
|
||||||
setMessages((prevMessages) => [...prevMessages, ...loadedMessages])
|
setMessages((prevMessages) => [...prevMessages, ...loadedMessages])
|
||||||
}
|
}
|
||||||
|
|
@ -196,6 +228,8 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
|
||||||
setMessages((prevMessages) => [...prevMessages, { message: '', type: 'apiMessage' }])
|
setMessages((prevMessages) => [...prevMessages, { message: '', type: 'apiMessage' }])
|
||||||
})
|
})
|
||||||
|
|
||||||
|
socket.on('sourceDocuments', updateLastMessageSourceDocuments)
|
||||||
|
|
||||||
socket.on('token', updateLastMessage)
|
socket.on('token', updateLastMessage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -225,6 +259,7 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
|
||||||
messages.map((message, index) => {
|
messages.map((message, index) => {
|
||||||
return (
|
return (
|
||||||
// The latest message sent by the user will be animated while waiting for a response
|
// The latest message sent by the user will be animated while waiting for a response
|
||||||
|
<>
|
||||||
<Box
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
background: message.type === 'apiMessage' ? theme.palette.asyncSelect.main : ''
|
background: message.type === 'apiMessage' ? theme.palette.asyncSelect.main : ''
|
||||||
|
|
@ -259,6 +294,7 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
|
||||||
className='usericon'
|
className='usericon'
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
<div style={{ display: 'flex', flexDirection: 'column', width: '100%' }}>
|
||||||
<div className='markdownanswer'>
|
<div className='markdownanswer'>
|
||||||
{/* Messages are being rendered in Markdown format */}
|
{/* Messages are being rendered in Markdown format */}
|
||||||
<MemoizedReactMarkdown
|
<MemoizedReactMarkdown
|
||||||
|
|
@ -287,7 +323,27 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
|
||||||
{message.message}
|
{message.message}
|
||||||
</MemoizedReactMarkdown>
|
</MemoizedReactMarkdown>
|
||||||
</div>
|
</div>
|
||||||
|
{message.sourceDocuments && (
|
||||||
|
<div style={{ display: 'block', flexDirection: 'row', width: '100%' }}>
|
||||||
|
{message.sourceDocuments.map((source, index) => {
|
||||||
|
return (
|
||||||
|
<Chip
|
||||||
|
size='small'
|
||||||
|
key={index}
|
||||||
|
label={`${source.pageContent.substring(0, 15)}...`}
|
||||||
|
component='a'
|
||||||
|
sx={{ mr: 1, mb: 1 }}
|
||||||
|
variant='outlined'
|
||||||
|
clickable
|
||||||
|
onClick={() => onSourceDialogClick(source)}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
</Box>
|
</Box>
|
||||||
|
</>
|
||||||
)
|
)
|
||||||
})}
|
})}
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -328,6 +384,7 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
|
||||||
</form>
|
</form>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<SourceDocDialog show={sourceDialogOpen} dialogProps={sourceDialogProps} onCancel={() => setSourceDialogOpen(false)} />
|
||||||
</>
|
</>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue