diff --git a/packages/components/src/textToSpeech.ts b/packages/components/src/textToSpeech.ts index 6cddbc4be..5363b2fec 100644 --- a/packages/components/src/textToSpeech.ts +++ b/packages/components/src/textToSpeech.ts @@ -14,6 +14,7 @@ export const convertTextToSpeechStream = async ( text: string, textToSpeechConfig: ICommonObject, options: ICommonObject, + abortController: AbortController, onStart: (format: string) => void, onChunk: (chunk: Buffer) => void, onEnd: () => void @@ -33,29 +34,34 @@ export const convertTextToSpeechStream = async ( apiKey: credentialData.openAIApiKey }) - const response = await openai.audio.speech.create({ - model: 'gpt-4o-mini-tts', - voice: (textToSpeechConfig.voice || 'alloy') as - | 'alloy' - | 'ash' - | 'ballad' - | 'coral' - | 'echo' - | 'fable' - | 'nova' - | 'onyx' - | 'sage' - | 'shimmer', - input: text, - response_format: 'mp3' - }) + const response = await openai.audio.speech.create( + { + model: 'gpt-4o-mini-tts', + voice: (textToSpeechConfig.voice || 'alloy') as + | 'alloy' + | 'ash' + | 'ballad' + | 'coral' + | 'echo' + | 'fable' + | 'nova' + | 'onyx' + | 'sage' + | 'shimmer', + input: text, + response_format: 'mp3' + }, + { + signal: abortController.signal + } + ) const stream = response.body as unknown as Readable if (!stream) { throw new Error('Failed to get response stream') } - await processStreamWithRateLimit(stream, onChunk, onEnd, resolve, reject, 640, 20) + await processStreamWithRateLimit(stream, onChunk, onEnd, resolve, reject, 640, 20, abortController) break } @@ -66,17 +72,21 @@ export const convertTextToSpeechStream = async ( apiKey: credentialData.elevenLabsApiKey }) - const response = await client.textToSpeech.stream(textToSpeechConfig.voice || '21m00Tcm4TlvDq8ikWAM', { - text: text, - modelId: 'eleven_multilingual_v2' - }) + const response = await client.textToSpeech.stream( + textToSpeechConfig.voice || '21m00Tcm4TlvDq8ikWAM', + { + text: text, + modelId: 'eleven_multilingual_v2' + }, + { abortSignal: abortController.signal } + ) const stream = Readable.fromWeb(response as unknown as ReadableStream) if (!stream) { throw new Error('Failed to get response stream') } - await processStreamWithRateLimit(stream, onChunk, onEnd, resolve, reject, 640, 40) + await processStreamWithRateLimit(stream, onChunk, onEnd, resolve, reject, 640, 40, abortController) break } } @@ -99,7 +109,8 @@ const processStreamWithRateLimit = async ( resolve: () => void, reject: (error: any) => void, targetChunkSize: number = 640, - rateLimitMs: number = 20 + rateLimitMs: number = 20, + abortController: AbortController ) => { const TARGET_CHUNK_SIZE = targetChunkSize const RATE_LIMIT_MS = rateLimitMs @@ -109,6 +120,13 @@ const processStreamWithRateLimit = async ( const processChunks = async () => { while (!isEnded || buffer.length > 0) { + // Check if aborted + if (abortController.signal.aborted) { + stream.destroy() + reject(new Error('TTS generation aborted')) + return + } + if (buffer.length >= TARGET_CHUNK_SIZE) { const chunk = buffer.subarray(0, TARGET_CHUNK_SIZE) buffer = buffer.subarray(TARGET_CHUNK_SIZE) @@ -129,7 +147,9 @@ const processStreamWithRateLimit = async ( } stream.on('data', (chunk) => { - buffer = Buffer.concat([buffer, Buffer.from(chunk)]) + if (!abortController.signal.aborted) { + buffer = Buffer.concat([buffer, Buffer.from(chunk)]) + } }) stream.on('end', () => { @@ -140,6 +160,12 @@ const processStreamWithRateLimit = async ( reject(error) }) + // Handle abort signal + abortController.signal.addEventListener('abort', () => { + stream.destroy() + reject(new Error('TTS generation aborted')) + }) + processChunks().catch(reject) } diff --git a/packages/server/src/controllers/text-to-speech/index.ts b/packages/server/src/controllers/text-to-speech/index.ts index 985bb263a..0f8c44007 100644 --- a/packages/server/src/controllers/text-to-speech/index.ts +++ b/packages/server/src/controllers/text-to-speech/index.ts @@ -92,37 +92,51 @@ const generateTextToSpeech = async (req: Request, res: Response) => { model: model } - await convertTextToSpeechStream( - text, - textToSpeechConfig, - options, - (format: string) => { - const startResponse = { - event: 'tts_start', - data: { chatMessageId, format } + // Create and store AbortController + const abortController = new AbortController() + const ttsAbortId = `tts_${chatId}_${chatMessageId}` + appServer.abortControllerPool.add(ttsAbortId, abortController) + + try { + await convertTextToSpeechStream( + text, + textToSpeechConfig, + options, + abortController, + (format: string) => { + const startResponse = { + event: 'tts_start', + data: { chatMessageId, format } + } + res.write('event: tts_start\n') + res.write(`data: ${JSON.stringify(startResponse)}\n\n`) + }, + (chunk: Buffer) => { + const audioBase64 = chunk.toString('base64') + const clientResponse = { + event: 'tts_data', + data: { chatMessageId, audioChunk: audioBase64 } + } + res.write('event: tts_data\n') + res.write(`data: ${JSON.stringify(clientResponse)}\n\n`) + }, + async () => { + const endResponse = { + event: 'tts_end', + data: { chatMessageId } + } + res.write('event: tts_end\n') + res.write(`data: ${JSON.stringify(endResponse)}\n\n`) + res.end() + // Clean up from pool on successful completion + appServer.abortControllerPool.remove(ttsAbortId) } - res.write('event: tts_start\n') - res.write(`data: ${JSON.stringify(startResponse)}\n\n`) - }, - (chunk: Buffer) => { - const audioBase64 = chunk.toString('base64') - const clientResponse = { - event: 'tts_data', - data: { chatMessageId, audioChunk: audioBase64 } - } - res.write('event: tts_data\n') - res.write(`data: ${JSON.stringify(clientResponse)}\n\n`) - }, - async () => { - const endResponse = { - event: 'tts_end', - data: { chatMessageId } - } - res.write('event: tts_end\n') - res.write(`data: ${JSON.stringify(endResponse)}\n\n`) - res.end() - } - ) + ) + } catch (error) { + // Clean up from pool on error + appServer.abortControllerPool.remove(ttsAbortId) + throw error + } } catch (error) { if (!res.headersSent) { res.setHeader('Content-Type', 'text/event-stream') @@ -160,6 +174,11 @@ const abortTextToSpeech = async (req: Request, res: Response) => { const appServer = getRunningExpressApp() + // Abort the TTS generation using existing pool + const ttsAbortId = `tts_${chatId}_${chatMessageId}` + appServer.abortControllerPool.abort(ttsAbortId) + + // Send abort event to client appServer.sseStreamer.streamTTSAbortEvent(chatId, chatMessageId) res.json({ message: 'TTS stream aborted successfully', chatId, chatMessageId }) diff --git a/packages/server/src/utils/buildAgentflow.ts b/packages/server/src/utils/buildAgentflow.ts index bf6c16f5c..d28a039a9 100644 --- a/packages/server/src/utils/buildAgentflow.ts +++ b/packages/server/src/utils/buildAgentflow.ts @@ -2189,7 +2189,15 @@ export const executeAgentFlow = async ({ } if (sseStreamer) { - await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, chatMessage?.id, sseStreamer) + await generateTTSForResponseStream( + result.text, + chatflow.textToSpeech, + options, + chatId, + chatMessage?.id, + sseStreamer, + abortController + ) } } diff --git a/packages/server/src/utils/buildChatflow.ts b/packages/server/src/utils/buildChatflow.ts index 9c5babc04..0a4b582aa 100644 --- a/packages/server/src/utils/buildChatflow.ts +++ b/packages/server/src/utils/buildChatflow.ts @@ -95,7 +95,8 @@ const generateTTSForResponseStream = async ( options: ICommonObject, chatId: string, chatMessageId: string, - sseStreamer: IServerSideEventStreamer + sseStreamer: IServerSideEventStreamer, + abortController?: AbortController ): Promise => { try { if (!textToSpeechConfig) return @@ -121,6 +122,7 @@ const generateTTSForResponseStream = async ( responseText, activeProviderConfig, options, + abortController || new AbortController(), (format: string) => { sseStreamer.streamTTSStartEvent(chatId, chatMessageId, format) }, @@ -908,9 +910,25 @@ export const executeFlow = async ({ } if (streaming && sseStreamer) { - await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, chatMessage?.id, sseStreamer) + await generateTTSForResponseStream( + result.text, + chatflow.textToSpeech, + options, + chatId, + chatMessage?.id, + sseStreamer, + signal + ) } else if (sseStreamer) { - await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, chatMessage?.id, sseStreamer) + await generateTTSForResponseStream( + result.text, + chatflow.textToSpeech, + options, + chatId, + chatMessage?.id, + sseStreamer, + signal + ) } }