diff --git a/packages/components/src/Interface.ts b/packages/components/src/Interface.ts index 811090481..72b657c29 100644 --- a/packages/components/src/Interface.ts +++ b/packages/components/src/Interface.ts @@ -442,9 +442,9 @@ export interface IServerSideEventStreamer { streamEndEvent(chatId: string): void streamUsageMetadataEvent(chatId: string, data: any): void streamAudioEvent(chatId: string, audioData: string): void - streamTTSStartEvent(chatId: string, format: string): void - streamTTSDataEvent(chatId: string, audioChunk: string): void - streamTTSEndEvent(chatId: string): void + streamTTSStartEvent(chatId: string, chatMessageId: string, format: string): void + streamTTSDataEvent(chatId: string, chatMessageId: string, audioChunk: string): void + streamTTSEndEvent(chatId: string, chatMessageId: string): void } export enum FollowUpPromptProvider { diff --git a/packages/server/src/controllers/internal-predictions/index.ts b/packages/server/src/controllers/internal-predictions/index.ts index 89bcefbb4..48ace57fb 100644 --- a/packages/server/src/controllers/internal-predictions/index.ts +++ b/packages/server/src/controllers/internal-predictions/index.ts @@ -50,7 +50,14 @@ const createAndStreamInternalPrediction = async (req: Request, res: Response, ne databaseEntities: getRunningExpressApp().AppDataSource?.entityMetadatas || [] } - await generateTTSForResponseStream(apiResponse.text, chatflow.textToSpeech, options, apiResponse.chatId, sseStreamer) + await generateTTSForResponseStream( + apiResponse.text, + chatflow.textToSpeech, + options, + apiResponse.chatId, + apiResponse.chatMessageId, + sseStreamer + ) } } catch (error) { if (chatId) { diff --git a/packages/server/src/controllers/predictions/index.ts b/packages/server/src/controllers/predictions/index.ts index 7616eacaa..47cc27c99 100644 --- a/packages/server/src/controllers/predictions/index.ts +++ b/packages/server/src/controllers/predictions/index.ts @@ -93,6 +93,7 @@ const createPrediction = async (req: Request, res: Response, next: NextFunction) chatflow.textToSpeech, options, apiResponse.chatId, + apiResponse.chatMessageId, sseStreamer ) } diff --git a/packages/server/src/controllers/text-to-speech/index.ts b/packages/server/src/controllers/text-to-speech/index.ts index 3d5a755aa..69a75d5ec 100644 --- a/packages/server/src/controllers/text-to-speech/index.ts +++ b/packages/server/src/controllers/text-to-speech/index.ts @@ -8,7 +8,7 @@ import { databaseEntities } from '../../utils' const generateTextToSpeech = async (req: Request, res: Response) => { try { - const { text, provider, credentialId, voice, model } = req.body + const { chatMessageId, text, provider, credentialId, voice, model } = req.body if (!text) { throw new InternalFlowiseError( @@ -60,7 +60,7 @@ const generateTextToSpeech = async (req: Request, res: Response) => { (format: string) => { const startResponse = { event: 'tts_start', - data: { format } + data: { chatMessageId, format } } res.write('event: tts_start\n') res.write(`data: ${JSON.stringify(startResponse)}\n\n`) @@ -69,7 +69,7 @@ const generateTextToSpeech = async (req: Request, res: Response) => { const audioBase64 = chunk.toString('base64') const clientResponse = { event: 'tts_data', - data: audioBase64 + data: { chatMessageId, audioChunk: audioBase64 } } res.write('event: tts_data\n') res.write(`data: ${JSON.stringify(clientResponse)}\n\n`) @@ -77,7 +77,7 @@ const generateTextToSpeech = async (req: Request, res: Response) => { async () => { const endResponse = { event: 'tts_end', - data: {} + data: { chatMessageId } } res.write('event: tts_end\n') res.write(`data: ${JSON.stringify(endResponse)}\n\n`) diff --git a/packages/server/src/utils/SSEStreamer.ts b/packages/server/src/utils/SSEStreamer.ts index 050e7f4b0..785293284 100644 --- a/packages/server/src/utils/SSEStreamer.ts +++ b/packages/server/src/utils/SSEStreamer.ts @@ -269,34 +269,34 @@ export class SSEStreamer implements IServerSideEventStreamer { } } - streamTTSStartEvent(chatId: string, format: string): void { + streamTTSStartEvent(chatId: string, chatMessageId: string, format: string): void { const client = this.clients[chatId] if (client) { const clientResponse = { event: 'tts_start', - data: { format } + data: { chatMessageId, format } } client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') } } - streamTTSDataEvent(chatId: string, audioChunk: string): void { + streamTTSDataEvent(chatId: string, chatMessageId: string, audioChunk: string): void { const client = this.clients[chatId] if (client) { const clientResponse = { event: 'tts_data', - data: audioChunk + data: { chatMessageId, audioChunk } } client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') } } - streamTTSEndEvent(chatId: string): void { + streamTTSEndEvent(chatId: string, chatMessageId: string): void { const client = this.clients[chatId] if (client) { const clientResponse = { event: 'tts_end', - data: {} + data: { chatMessageId } } client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') } diff --git a/packages/server/src/utils/buildAgentflow.ts b/packages/server/src/utils/buildAgentflow.ts index 7053020c6..464ff850e 100644 --- a/packages/server/src/utils/buildAgentflow.ts +++ b/packages/server/src/utils/buildAgentflow.ts @@ -2049,7 +2049,7 @@ export const executeAgentFlow = async ({ } if (sseStreamer) { - await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, sseStreamer) + await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, chatMessage?.id, sseStreamer) } } diff --git a/packages/server/src/utils/buildChatflow.ts b/packages/server/src/utils/buildChatflow.ts index 652c778a9..78e2c7486 100644 --- a/packages/server/src/utils/buildChatflow.ts +++ b/packages/server/src/utils/buildChatflow.ts @@ -94,6 +94,7 @@ const generateTTSForResponseStream = async ( textToSpeechConfig: string | undefined, options: ICommonObject, chatId: string, + chatMessageId: string, sseStreamer: IServerSideEventStreamer ): Promise => { try { @@ -121,19 +122,19 @@ const generateTTSForResponseStream = async ( activeProviderConfig, options, (format: string) => { - sseStreamer.streamTTSStartEvent(chatId, format) + sseStreamer.streamTTSStartEvent(chatId, chatMessageId, format) }, (chunk: Buffer) => { const audioBase64 = chunk.toString('base64') - sseStreamer.streamTTSDataEvent(chatId, audioBase64) + sseStreamer.streamTTSDataEvent(chatId, chatMessageId, audioBase64) }, () => { - sseStreamer.streamTTSEndEvent(chatId) + sseStreamer.streamTTSEndEvent(chatId, chatMessageId) } ) } catch (error) { logger.error(`[server]: TTS streaming failed: ${getErrorMessage(error)}`) - sseStreamer.streamTTSEndEvent(chatId) + sseStreamer.streamTTSEndEvent(chatId, chatMessageId) } } @@ -902,9 +903,9 @@ export const executeFlow = async ({ } if (streaming && sseStreamer) { - await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, sseStreamer) + await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, chatMessage?.id, sseStreamer) } else if (sseStreamer) { - await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, sseStreamer) + await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, chatMessage?.id, sseStreamer) } } diff --git a/packages/ui/src/views/chatmessage/ChatMessage.jsx b/packages/ui/src/views/chatmessage/ChatMessage.jsx index 8b006f449..b60ed90ff 100644 --- a/packages/ui/src/views/chatmessage/ChatMessage.jsx +++ b/packages/ui/src/views/chatmessage/ChatMessage.jsx @@ -39,7 +39,8 @@ import { IconCheck, IconPaperclip, IconSparkles, - IconVolume + IconVolume, + IconSquare } from '@tabler/icons-react' import robotPNG from '@/assets/images/robot.png' import userPNG from '@/assets/images/account.png' @@ -253,7 +254,8 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP const [isConfigLoading, setIsConfigLoading] = useState(true) // TTS state - const [ttsLoading, setTtsLoading] = useState({}) + const [isTTSLoading, setIsTTSLoading] = useState({}) + const [isTTSPlaying, setIsTTSPlaying] = useState({}) const [ttsAudio, setTtsAudio] = useState({}) const [isTTSEnabled, setIsTTSEnabled] = useState(false) @@ -1053,10 +1055,10 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP handleAutoPlayAudio(payload.data) break case 'tts_start': - handleTTSStart(payload.data.format) + handleTTSStart(payload.data) break case 'tts_data': - handleTTSDataChunk(payload.data) + handleTTSDataChunk(payload.data.audioChunk) break case 'tts_end': handleTTSEnd() @@ -1559,9 +1561,7 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP setIsLeadSaving(false) } - const handleTTSClick = async (messageId, messageText) => { - if (ttsLoading[messageId]) return - + const handleTTSStop = (messageId) => { if (ttsAudio[messageId]) { ttsAudio[messageId].pause() ttsAudio[messageId].currentTime = 0 @@ -1570,14 +1570,38 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP delete newState[messageId] return newState }) + } + + if (ttsStreamingState.audio) { + ttsStreamingState.audio.pause() + cleanupTTSStreaming() + } + + setIsTTSPlaying((prev) => { + const newState = { ...prev } + delete newState[messageId] + return newState + }) + + setIsTTSLoading((prev) => { + const newState = { ...prev } + delete newState[messageId] + return newState + }) + } + + const handleTTSClick = async (messageId, messageText) => { + if (isTTSLoading[messageId]) return + + if (isTTSPlaying[messageId] || ttsAudio[messageId]) { + handleTTSStop(messageId) return } - setTtsLoading((prev) => ({ ...prev, [messageId]: true })) - + handleTTSStart({ chatMessageId: messageId, format: 'mp3' }) try { let ttsConfig = null - if (getChatflowConfig.data && getChatflowConfig.data.textToSpeech) { + if (getChatflowConfig?.data?.textToSpeech) { try { ttsConfig = typeof getChatflowConfig.data.textToSpeech === 'string' @@ -1592,7 +1616,7 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP let providerConfig = null if (ttsConfig) { Object.keys(ttsConfig).forEach((provider) => { - if (ttsConfig[provider] && ttsConfig[provider].status) { + if (ttsConfig?.[provider]?.status) { activeProvider = provider providerConfig = ttsConfig[provider] } @@ -1607,19 +1631,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP return } - // Use existing streaming infrastructure for manual TTS - handleTTSStart('mp3', (audio) => { - setTtsAudio((prev) => ({ ...prev, [messageId]: audio })) - - audio.addEventListener('ended', () => { - setTtsAudio((prev) => { - const newState = { ...prev } - delete newState[messageId] - return newState - }) - }) - }) - const response = await fetch('/api/v1/text-to-speech/generate', { method: 'POST', headers: { @@ -1628,6 +1639,8 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP }, credentials: 'include', body: JSON.stringify({ + chatId: chatId, + chatMessageId: messageId, text: messageText, provider: activeProvider, credentialId: providerConfig.credentialId, @@ -1652,25 +1665,21 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP break } const value = result.value - - // Decode the chunk as text and add to buffer const chunk = decoder.decode(value, { stream: true }) buffer += chunk - // Process complete SSE events const lines = buffer.split('\n\n') - buffer = lines.pop() || '' // Keep incomplete event in buffer + buffer = lines.pop() || '' for (const eventBlock of lines) { if (eventBlock.trim()) { const event = parseSSEEvent(eventBlock) if (event) { - // Handle the event just like the SSE handler does switch (event.event) { case 'tts_start': break case 'tts_data': - handleTTSDataChunk(event.data) + handleTTSDataChunk(event.data.audioChunk) break case 'tts_end': handleTTSEnd() @@ -1689,7 +1698,7 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP options: { variant: 'error' } }) } finally { - setTtsLoading((prev) => { + setIsTTSLoading((prev) => { const newState = { ...prev } delete newState[messageId] return newState @@ -1699,7 +1708,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP const handleAutoPlayAudio = async (audioData) => { try { - // Convert base64 audio data to blob and play const audioBuffer = Uint8Array.from(atob(audioData), (c) => c.charCodeAt(0)) const audioBlob = new Blob([audioBuffer], { type: 'audio/mpeg' }) const audioUrl = URL.createObjectURL(audioBlob) @@ -1712,19 +1720,10 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP await audio.play() } catch (error) { console.error('Error playing auto TTS audio:', error) - // Fallback: Use manual TTS API call - const lastMessage = messages[messages.length - 1] - if (lastMessage && lastMessage.type === 'apiMessage' && lastMessage.message) { - try { - await handleTTSClick(lastMessage.id, lastMessage.message) - } catch (fallbackError) { - console.error('TTS fallback also failed:', fallbackError) - enqueueSnackbar({ - message: 'Auto-play audio failed', - options: { variant: 'error' } - }) - } - } + enqueueSnackbar({ + message: 'Auto-play audio failed', + options: { variant: 'error' } + }) } } @@ -1751,7 +1750,7 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP return event.event ? event : null } - const initializeTTSStreaming = (format, onAudioReady = null) => { + const initializeTTSStreaming = (data) => { try { const mediaSource = new MediaSource() const audio = new Audio() @@ -1759,9 +1758,7 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP mediaSource.addEventListener('sourceopen', () => { try { - // Use the provided format, default to MP3 if not set - const mimeType = format === 'mp3' ? 'audio/mpeg' : 'audio/mpeg' - + const mimeType = data.format === 'mp3' ? 'audio/mpeg' : 'audio/mpeg' const sourceBuffer = mediaSource.addSourceBuffer(mimeType) setTtsStreamingState((prevState) => ({ @@ -1771,16 +1768,9 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP audio })) - // Start playback - audio.play().catch((playError) => { console.error('Error starting audio playback:', playError) }) - - // Notify callback if provided - if (onAudioReady) { - onAudioReady(audio) - } } catch (error) { console.error('Error setting up source buffer:', error) console.error('MediaSource readyState:', mediaSource.readyState) @@ -1788,7 +1778,24 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP } }) + audio.addEventListener('playing', () => { + setIsTTSLoading((prevState) => { + const newState = { ...prevState } + newState[data.chatMessageId] = false + return newState + }) + setIsTTSPlaying((prevState) => ({ + ...prevState, + [data.chatMessageId]: true + })) + }) + audio.addEventListener('ended', () => { + setIsTTSPlaying((prevState) => { + const newState = { ...prevState } + delete newState[data.chatMessageId] + return newState + }) cleanupTTSStreaming() }) } catch (error) { @@ -1850,10 +1857,20 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP }) } - const handleTTSStart = (format, onAudioReady = null) => { - // Store the audio format for this TTS session and initialize + const handleTTSStart = (data) => { + setIsTTSLoading((prevState) => ({ + ...prevState, + [data.chatMessageId]: true + })) + setMessages((prevMessages) => { + const allMessages = [...cloneDeep(prevMessages)] + const lastMessage = allMessages[allMessages.length - 1] + if (lastMessage.type === 'userMessage') return allMessages + if (lastMessage.id) return allMessages + allMessages[allMessages.length - 1].id = data.chatMessageId + return allMessages + }) setTtsStreamingState((prevState) => { - // Cleanup any existing streaming first if (prevState.audio) { prevState.audio.pause() if (prevState.audio.src) { @@ -1864,8 +1881,8 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP if (prevState.mediaSource && prevState.mediaSource.readyState === 'open') { try { prevState.mediaSource.endOfStream() - } catch (e) { - // Ignore errors during cleanup + } catch (error) { + console.error('Error stopping previous media source:', error) } } @@ -1875,12 +1892,11 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP audio: null, chunkQueue: [], isBuffering: false, - audioFormat: format + audioFormat: data.format } }) - // Initialize TTS streaming with the correct format - setTimeout(() => initializeTTSStreaming(format, onAudioReady), 0) + setTimeout(() => initializeTTSStreaming(data), 0) } const handleTTSDataChunk = (base64Data) => { @@ -1888,13 +1904,11 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP const audioBuffer = Uint8Array.from(atob(base64Data), (c) => c.charCodeAt(0)) setTtsStreamingState((prevState) => { - // Add chunk to queue const newState = { ...prevState, chunkQueue: [...prevState.chunkQueue, audioBuffer] } - // Process queue if sourceBuffer is ready if (prevState.sourceBuffer && !prevState.sourceBuffer.updating) { setTimeout(() => processChunkQueue(), 0) } @@ -1910,7 +1924,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP setTtsStreamingState((prevState) => { if (prevState.mediaSource && prevState.mediaSource.readyState === 'open') { try { - // Process any remaining chunks first if (prevState.sourceBuffer && prevState.chunkQueue.length > 0 && !prevState.sourceBuffer.updating) { const remainingChunks = [...prevState.chunkQueue] remainingChunks.forEach((chunk, index) => { @@ -1919,7 +1932,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP try { prevState.sourceBuffer.appendBuffer(chunk) if (index === remainingChunks.length - 1) { - // End stream after last chunk setTimeout(() => { if (prevState.mediaSource && prevState.mediaSource.readyState === 'open') { prevState.mediaSource.endOfStream() @@ -1938,11 +1950,9 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP } } - // Wait for any pending buffer operations to complete if (prevState.sourceBuffer && !prevState.sourceBuffer.updating) { prevState.mediaSource.endOfStream() } else if (prevState.sourceBuffer) { - // Wait for buffer to finish updating prevState.sourceBuffer.addEventListener( 'updateend', () => { @@ -1961,7 +1971,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP }) } - // Set up sourceBuffer event listeners when it changes useEffect(() => { if (ttsStreamingState.sourceBuffer) { const sourceBuffer = ttsStreamingState.sourceBuffer @@ -1971,7 +1980,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP ...prevState, isBuffering: false })) - // Process next chunk in queue setTimeout(() => processChunkQueue(), 0) } @@ -1983,7 +1991,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP } }, [ttsStreamingState.sourceBuffer]) - // Cleanup TTS streaming on component unmount useEffect(() => { return () => { cleanupTTSStreaming() @@ -2654,8 +2661,12 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP {isTTSEnabled && ( handleTTSClick(message.id, message.message)} - disabled={ttsLoading[message.id]} + onClick={() => + isTTSPlaying[message.id] + ? handleTTSStop(message.id) + : handleTTSClick(message.id, message.message) + } + disabled={isTTSLoading[message.id]} sx={{ backgroundColor: ttsAudio[message.id] ? 'primary.main' : 'transparent', color: ttsAudio[message.id] ? 'white' : 'inherit', @@ -2664,8 +2675,10 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP } }} > - {ttsLoading[message.id] ? ( + {isTTSLoading[message.id] ? ( + ) : isTTSPlaying[message.id] ? ( + ) : ( )}