Refactor TTS - fix issues with TTS loading and stop audio buttons

Ilango Rajagopal 2025-08-25 02:38:09 +05:30
parent 8de200ee15
commit 55b6be24df
8 changed files with 119 additions and 97 deletions

View File

@ -442,9 +442,9 @@ export interface IServerSideEventStreamer {
streamEndEvent(chatId: string): void
streamUsageMetadataEvent(chatId: string, data: any): void
streamAudioEvent(chatId: string, audioData: string): void
streamTTSStartEvent(chatId: string, format: string): void
streamTTSDataEvent(chatId: string, audioChunk: string): void
streamTTSEndEvent(chatId: string): void
streamTTSStartEvent(chatId: string, chatMessageId: string, format: string): void
streamTTSDataEvent(chatId: string, chatMessageId: string, audioChunk: string): void
streamTTSEndEvent(chatId: string, chatMessageId: string): void
}
export enum FollowUpPromptProvider {
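
For context (not part of the commit): after this interface change, all three TTS stream events carry the chat message id alongside their existing fields. A minimal sketch of the payload shapes as a TypeScript union, inferred from this diff (the TTSStreamEvent name is illustrative only):

type TTSStreamEvent =
    | { event: 'tts_start'; data: { chatMessageId: string; format: string } }
    | { event: 'tts_data'; data: { chatMessageId: string; audioChunk: string } } // audioChunk is base64-encoded audio
    | { event: 'tts_end'; data: { chatMessageId: string } }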

View File

@ -50,7 +50,14 @@ const createAndStreamInternalPrediction = async (req: Request, res: Response, ne
databaseEntities: getRunningExpressApp().AppDataSource?.entityMetadatas || []
}
await generateTTSForResponseStream(apiResponse.text, chatflow.textToSpeech, options, apiResponse.chatId, sseStreamer)
await generateTTSForResponseStream(
apiResponse.text,
chatflow.textToSpeech,
options,
apiResponse.chatId,
apiResponse.chatMessageId,
sseStreamer
)
}
} catch (error) {
if (chatId) {

View File

@ -93,6 +93,7 @@ const createPrediction = async (req: Request, res: Response, next: NextFunction)
chatflow.textToSpeech,
options,
apiResponse.chatId,
apiResponse.chatMessageId,
sseStreamer
)
}

View File

@ -8,7 +8,7 @@ import { databaseEntities } from '../../utils'
const generateTextToSpeech = async (req: Request, res: Response) => {
try {
const { text, provider, credentialId, voice, model } = req.body
const { chatMessageId, text, provider, credentialId, voice, model } = req.body
if (!text) {
throw new InternalFlowiseError(
@ -60,7 +60,7 @@ const generateTextToSpeech = async (req: Request, res: Response) => {
(format: string) => {
const startResponse = {
event: 'tts_start',
data: { format }
data: { chatMessageId, format }
}
res.write('event: tts_start\n')
res.write(`data: ${JSON.stringify(startResponse)}\n\n`)
@ -69,7 +69,7 @@ const generateTextToSpeech = async (req: Request, res: Response) => {
const audioBase64 = chunk.toString('base64')
const clientResponse = {
event: 'tts_data',
data: audioBase64
data: { chatMessageId, audioChunk: audioBase64 }
}
res.write('event: tts_data\n')
res.write(`data: ${JSON.stringify(clientResponse)}\n\n`)
@ -77,7 +77,7 @@ const generateTextToSpeech = async (req: Request, res: Response) => {
async () => {
const endResponse = {
event: 'tts_end',
data: {}
data: { chatMessageId }
}
res.write('event: tts_end\n')
res.write(`data: ${JSON.stringify(endResponse)}\n\n`)
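
Illustrative only (values are placeholders): with these changes, a single audio chunk emitted by this endpoint appears on the wire as the following SSE frame; the tts_start and tts_end frames follow the same pattern with the payloads shown above:

event: tts_data
data: {"event":"tts_data","data":{"chatMessageId":"<message-id>","audioChunk":"<base64-audio>"}}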

View File

@ -269,34 +269,34 @@ export class SSEStreamer implements IServerSideEventStreamer {
}
}
streamTTSStartEvent(chatId: string, format: string): void {
streamTTSStartEvent(chatId: string, chatMessageId: string, format: string): void {
const client = this.clients[chatId]
if (client) {
const clientResponse = {
event: 'tts_start',
data: { format }
data: { chatMessageId, format }
}
client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n')
}
}
streamTTSDataEvent(chatId: string, audioChunk: string): void {
streamTTSDataEvent(chatId: string, chatMessageId: string, audioChunk: string): void {
const client = this.clients[chatId]
if (client) {
const clientResponse = {
event: 'tts_data',
data: audioChunk
data: { chatMessageId, audioChunk }
}
client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n')
}
}
streamTTSEndEvent(chatId: string): void {
streamTTSEndEvent(chatId: string, chatMessageId: string): void {
const client = this.clients[chatId]
if (client) {
const clientResponse = {
event: 'tts_end',
data: {}
data: { chatMessageId }
}
client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n')
}
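
Illustrative only (values are placeholders): the internal SSEStreamer frames the same payload slightly differently, so a tts_data event written by these methods now looks like:

message:
data:{"event":"tts_data","data":{"chatMessageId":"<message-id>","audioChunk":"<base64-audio>"}}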

View File

@ -2049,7 +2049,7 @@ export const executeAgentFlow = async ({
}
if (sseStreamer) {
await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, sseStreamer)
await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, chatMessage?.id, sseStreamer)
}
}

View File

@ -94,6 +94,7 @@ const generateTTSForResponseStream = async (
textToSpeechConfig: string | undefined,
options: ICommonObject,
chatId: string,
chatMessageId: string,
sseStreamer: IServerSideEventStreamer
): Promise<void> => {
try {
@ -121,19 +122,19 @@ const generateTTSForResponseStream = async (
activeProviderConfig,
options,
(format: string) => {
sseStreamer.streamTTSStartEvent(chatId, format)
sseStreamer.streamTTSStartEvent(chatId, chatMessageId, format)
},
(chunk: Buffer) => {
const audioBase64 = chunk.toString('base64')
sseStreamer.streamTTSDataEvent(chatId, audioBase64)
sseStreamer.streamTTSDataEvent(chatId, chatMessageId, audioBase64)
},
() => {
sseStreamer.streamTTSEndEvent(chatId)
sseStreamer.streamTTSEndEvent(chatId, chatMessageId)
}
)
} catch (error) {
logger.error(`[server]: TTS streaming failed: ${getErrorMessage(error)}`)
sseStreamer.streamTTSEndEvent(chatId)
sseStreamer.streamTTSEndEvent(chatId, chatMessageId)
}
}
@ -902,9 +903,9 @@ export const executeFlow = async ({
}
if (streaming && sseStreamer) {
await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, sseStreamer)
await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, chatMessage?.id, sseStreamer)
} else if (sseStreamer) {
await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, sseStreamer)
await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, chatMessage?.id, sseStreamer)
}
}

View File

@ -39,7 +39,8 @@ import {
IconCheck,
IconPaperclip,
IconSparkles,
IconVolume
IconVolume,
IconSquare
} from '@tabler/icons-react'
import robotPNG from '@/assets/images/robot.png'
import userPNG from '@/assets/images/account.png'
@ -253,7 +254,8 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
const [isConfigLoading, setIsConfigLoading] = useState(true)
// TTS state
const [ttsLoading, setTtsLoading] = useState({})
const [isTTSLoading, setIsTTSLoading] = useState({})
const [isTTSPlaying, setIsTTSPlaying] = useState({})
const [ttsAudio, setTtsAudio] = useState({})
const [isTTSEnabled, setIsTTSEnabled] = useState(false)
@ -1053,10 +1055,10 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
handleAutoPlayAudio(payload.data)
break
case 'tts_start':
handleTTSStart(payload.data.format)
handleTTSStart(payload.data)
break
case 'tts_data':
handleTTSDataChunk(payload.data)
handleTTSDataChunk(payload.data.audioChunk)
break
case 'tts_end':
handleTTSEnd()
@ -1559,9 +1561,7 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
setIsLeadSaving(false)
}
const handleTTSClick = async (messageId, messageText) => {
if (ttsLoading[messageId]) return
const handleTTSStop = (messageId) => {
if (ttsAudio[messageId]) {
ttsAudio[messageId].pause()
ttsAudio[messageId].currentTime = 0
@ -1570,14 +1570,38 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
delete newState[messageId]
return newState
})
}
if (ttsStreamingState.audio) {
ttsStreamingState.audio.pause()
cleanupTTSStreaming()
}
setIsTTSPlaying((prev) => {
const newState = { ...prev }
delete newState[messageId]
return newState
})
setIsTTSLoading((prev) => {
const newState = { ...prev }
delete newState[messageId]
return newState
})
}
const handleTTSClick = async (messageId, messageText) => {
if (isTTSLoading[messageId]) return
if (isTTSPlaying[messageId] || ttsAudio[messageId]) {
handleTTSStop(messageId)
return
}
setTtsLoading((prev) => ({ ...prev, [messageId]: true }))
handleTTSStart({ chatMessageId: messageId, format: 'mp3' })
try {
let ttsConfig = null
if (getChatflowConfig.data && getChatflowConfig.data.textToSpeech) {
if (getChatflowConfig?.data?.textToSpeech) {
try {
ttsConfig =
typeof getChatflowConfig.data.textToSpeech === 'string'
@ -1592,7 +1616,7 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
let providerConfig = null
if (ttsConfig) {
Object.keys(ttsConfig).forEach((provider) => {
if (ttsConfig[provider] && ttsConfig[provider].status) {
if (ttsConfig?.[provider]?.status) {
activeProvider = provider
providerConfig = ttsConfig[provider]
}
@ -1607,19 +1631,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
return
}
// Use existing streaming infrastructure for manual TTS
handleTTSStart('mp3', (audio) => {
setTtsAudio((prev) => ({ ...prev, [messageId]: audio }))
audio.addEventListener('ended', () => {
setTtsAudio((prev) => {
const newState = { ...prev }
delete newState[messageId]
return newState
})
})
})
const response = await fetch('/api/v1/text-to-speech/generate', {
method: 'POST',
headers: {
@ -1628,6 +1639,8 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
},
credentials: 'include',
body: JSON.stringify({
chatId: chatId,
chatMessageId: messageId,
text: messageText,
provider: activeProvider,
credentialId: providerConfig.credentialId,
@ -1652,25 +1665,21 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
break
}
const value = result.value
// Decode the chunk as text and add to buffer
const chunk = decoder.decode(value, { stream: true })
buffer += chunk
// Process complete SSE events
const lines = buffer.split('\n\n')
buffer = lines.pop() || '' // Keep incomplete event in buffer
buffer = lines.pop() || ''
for (const eventBlock of lines) {
if (eventBlock.trim()) {
const event = parseSSEEvent(eventBlock)
if (event) {
// Handle the event just like the SSE handler does
switch (event.event) {
case 'tts_start':
break
case 'tts_data':
handleTTSDataChunk(event.data)
handleTTSDataChunk(event.data.audioChunk)
break
case 'tts_end':
handleTTSEnd()
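
The parseSSEEvent helper used above is not shown in this diff. A hypothetical helper with the same role, sketched only from how it is called here (the parseSSEBlock name is made up for illustration), could look like:

const parseSSEBlock = (eventBlock) => {
    // Take the "data:" line of the SSE block and parse its JSON payload
    const dataLine = eventBlock.split('\n').find((line) => line.startsWith('data:'))
    if (!dataLine) return null
    try {
        const event = JSON.parse(dataLine.slice('data:'.length).trim())
        // Mirror the existing guard: only return objects that carry an event name
        return event && event.event ? event : null
    } catch (e) {
        return null
    }
}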
@ -1689,7 +1698,7 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
options: { variant: 'error' }
})
} finally {
setTtsLoading((prev) => {
setIsTTSLoading((prev) => {
const newState = { ...prev }
delete newState[messageId]
return newState
@ -1699,7 +1708,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
const handleAutoPlayAudio = async (audioData) => {
try {
// Convert base64 audio data to blob and play
const audioBuffer = Uint8Array.from(atob(audioData), (c) => c.charCodeAt(0))
const audioBlob = new Blob([audioBuffer], { type: 'audio/mpeg' })
const audioUrl = URL.createObjectURL(audioBlob)
@ -1712,19 +1720,10 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
await audio.play()
} catch (error) {
console.error('Error playing auto TTS audio:', error)
// Fallback: Use manual TTS API call
const lastMessage = messages[messages.length - 1]
if (lastMessage && lastMessage.type === 'apiMessage' && lastMessage.message) {
try {
await handleTTSClick(lastMessage.id, lastMessage.message)
} catch (fallbackError) {
console.error('TTS fallback also failed:', fallbackError)
enqueueSnackbar({
message: 'Auto-play audio failed',
options: { variant: 'error' }
})
}
}
enqueueSnackbar({
message: 'Auto-play audio failed',
options: { variant: 'error' }
})
}
}
@ -1751,7 +1750,7 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
return event.event ? event : null
}
const initializeTTSStreaming = (format, onAudioReady = null) => {
const initializeTTSStreaming = (data) => {
try {
const mediaSource = new MediaSource()
const audio = new Audio()
@ -1759,9 +1758,7 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
mediaSource.addEventListener('sourceopen', () => {
try {
// Use the provided format, default to MP3 if not set
const mimeType = format === 'mp3' ? 'audio/mpeg' : 'audio/mpeg'
const mimeType = data.format === 'mp3' ? 'audio/mpeg' : 'audio/mpeg'
const sourceBuffer = mediaSource.addSourceBuffer(mimeType)
setTtsStreamingState((prevState) => ({
@ -1771,16 +1768,9 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
audio
}))
// Start playback
audio.play().catch((playError) => {
console.error('Error starting audio playback:', playError)
})
// Notify callback if provided
if (onAudioReady) {
onAudioReady(audio)
}
} catch (error) {
console.error('Error setting up source buffer:', error)
console.error('MediaSource readyState:', mediaSource.readyState)
@ -1788,7 +1778,24 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
}
})
audio.addEventListener('playing', () => {
setIsTTSLoading((prevState) => {
const newState = { ...prevState }
newState[data.chatMessageId] = false
return newState
})
setIsTTSPlaying((prevState) => ({
...prevState,
[data.chatMessageId]: true
}))
})
audio.addEventListener('ended', () => {
setIsTTSPlaying((prevState) => {
const newState = { ...prevState }
delete newState[data.chatMessageId]
return newState
})
cleanupTTSStreaming()
})
} catch (error) {
@ -1850,10 +1857,20 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
})
}
const handleTTSStart = (format, onAudioReady = null) => {
// Store the audio format for this TTS session and initialize
const handleTTSStart = (data) => {
setIsTTSLoading((prevState) => ({
...prevState,
[data.chatMessageId]: true
}))
setMessages((prevMessages) => {
const allMessages = [...cloneDeep(prevMessages)]
const lastMessage = allMessages[allMessages.length - 1]
if (lastMessage.type === 'userMessage') return allMessages
if (lastMessage.id) return allMessages
allMessages[allMessages.length - 1].id = data.chatMessageId
return allMessages
})
setTtsStreamingState((prevState) => {
// Cleanup any existing streaming first
if (prevState.audio) {
prevState.audio.pause()
if (prevState.audio.src) {
@ -1864,8 +1881,8 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
if (prevState.mediaSource && prevState.mediaSource.readyState === 'open') {
try {
prevState.mediaSource.endOfStream()
} catch (e) {
// Ignore errors during cleanup
} catch (error) {
console.error('Error stopping previous media source:', error)
}
}
@ -1875,12 +1892,11 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
audio: null,
chunkQueue: [],
isBuffering: false,
audioFormat: format
audioFormat: data.format
}
})
// Initialize TTS streaming with the correct format
setTimeout(() => initializeTTSStreaming(format, onAudioReady), 0)
setTimeout(() => initializeTTSStreaming(data), 0)
}
const handleTTSDataChunk = (base64Data) => {
@ -1888,13 +1904,11 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
const audioBuffer = Uint8Array.from(atob(base64Data), (c) => c.charCodeAt(0))
setTtsStreamingState((prevState) => {
// Add chunk to queue
const newState = {
...prevState,
chunkQueue: [...prevState.chunkQueue, audioBuffer]
}
// Process queue if sourceBuffer is ready
if (prevState.sourceBuffer && !prevState.sourceBuffer.updating) {
setTimeout(() => processChunkQueue(), 0)
}
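
For readers less familiar with the Media Source Extensions pattern the chunk handling above relies on, here is a simplified, self-contained sketch of the same idea (the playBase64AudioChunks name and the fixed chunk array are illustrative; the component streams chunks as they arrive and queues them while the buffer is updating):

const playBase64AudioChunks = (chunks) => {
    const mediaSource = new MediaSource()
    const audio = new Audio()
    audio.src = URL.createObjectURL(mediaSource)

    mediaSource.addEventListener('sourceopen', () => {
        // MP3 chunks, matching the 'audio/mpeg' mime type used above
        const sourceBuffer = mediaSource.addSourceBuffer('audio/mpeg')
        let i = 0
        const appendNext = () => {
            if (i >= chunks.length) {
                if (!sourceBuffer.updating) mediaSource.endOfStream()
                return
            }
            // Decode the base64 chunk to bytes and append it to the source buffer
            const bytes = Uint8Array.from(atob(chunks[i++]), (c) => c.charCodeAt(0))
            sourceBuffer.appendBuffer(bytes)
        }
        // appendBuffer is asynchronous; wait for each append to finish before the next
        sourceBuffer.addEventListener('updateend', appendNext)
        appendNext()
        audio.play().catch((playError) => console.error('Error starting audio playback:', playError))
    })
}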
@ -1910,7 +1924,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
setTtsStreamingState((prevState) => {
if (prevState.mediaSource && prevState.mediaSource.readyState === 'open') {
try {
// Process any remaining chunks first
if (prevState.sourceBuffer && prevState.chunkQueue.length > 0 && !prevState.sourceBuffer.updating) {
const remainingChunks = [...prevState.chunkQueue]
remainingChunks.forEach((chunk, index) => {
@ -1919,7 +1932,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
try {
prevState.sourceBuffer.appendBuffer(chunk)
if (index === remainingChunks.length - 1) {
// End stream after last chunk
setTimeout(() => {
if (prevState.mediaSource && prevState.mediaSource.readyState === 'open') {
prevState.mediaSource.endOfStream()
@ -1938,11 +1950,9 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
}
}
// Wait for any pending buffer operations to complete
if (prevState.sourceBuffer && !prevState.sourceBuffer.updating) {
prevState.mediaSource.endOfStream()
} else if (prevState.sourceBuffer) {
// Wait for buffer to finish updating
prevState.sourceBuffer.addEventListener(
'updateend',
() => {
@ -1961,7 +1971,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
})
}
// Set up sourceBuffer event listeners when it changes
useEffect(() => {
if (ttsStreamingState.sourceBuffer) {
const sourceBuffer = ttsStreamingState.sourceBuffer
@ -1971,7 +1980,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
...prevState,
isBuffering: false
}))
// Process next chunk in queue
setTimeout(() => processChunkQueue(), 0)
}
@ -1983,7 +1991,6 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
}
}, [ttsStreamingState.sourceBuffer])
// Cleanup TTS streaming on component unmount
useEffect(() => {
return () => {
cleanupTTSStreaming()
@ -2654,8 +2661,12 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
{isTTSEnabled && (
<IconButton
size='small'
onClick={() => handleTTSClick(message.id, message.message)}
disabled={ttsLoading[message.id]}
onClick={() =>
isTTSPlaying[message.id]
? handleTTSStop(message.id)
: handleTTSClick(message.id, message.message)
}
disabled={isTTSLoading[message.id]}
sx={{
backgroundColor: ttsAudio[message.id] ? 'primary.main' : 'transparent',
color: ttsAudio[message.id] ? 'white' : 'inherit',
@ -2664,8 +2675,10 @@ const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, previews, setP
}
}}
>
{ttsLoading[message.id] ? (
{isTTSLoading[message.id] ? (
<CircularProgress size={16} />
) : isTTSPlaying[message.id] ? (
<IconSquare size={16} />
) : (
<IconVolume size={16} />
)}