TTS abort controller
This commit is contained in:
parent
3198e7817e
commit
27da0b62bd
|
|
@ -14,6 +14,7 @@ export const convertTextToSpeechStream = async (
|
|||
text: string,
|
||||
textToSpeechConfig: ICommonObject,
|
||||
options: ICommonObject,
|
||||
abortController: AbortController,
|
||||
onStart: (format: string) => void,
|
||||
onChunk: (chunk: Buffer) => void,
|
||||
onEnd: () => void
|
||||
|
|
@ -33,29 +34,34 @@ export const convertTextToSpeechStream = async (
|
|||
apiKey: credentialData.openAIApiKey
|
||||
})
|
||||
|
||||
const response = await openai.audio.speech.create({
|
||||
model: 'gpt-4o-mini-tts',
|
||||
voice: (textToSpeechConfig.voice || 'alloy') as
|
||||
| 'alloy'
|
||||
| 'ash'
|
||||
| 'ballad'
|
||||
| 'coral'
|
||||
| 'echo'
|
||||
| 'fable'
|
||||
| 'nova'
|
||||
| 'onyx'
|
||||
| 'sage'
|
||||
| 'shimmer',
|
||||
input: text,
|
||||
response_format: 'mp3'
|
||||
})
|
||||
const response = await openai.audio.speech.create(
|
||||
{
|
||||
model: 'gpt-4o-mini-tts',
|
||||
voice: (textToSpeechConfig.voice || 'alloy') as
|
||||
| 'alloy'
|
||||
| 'ash'
|
||||
| 'ballad'
|
||||
| 'coral'
|
||||
| 'echo'
|
||||
| 'fable'
|
||||
| 'nova'
|
||||
| 'onyx'
|
||||
| 'sage'
|
||||
| 'shimmer',
|
||||
input: text,
|
||||
response_format: 'mp3'
|
||||
},
|
||||
{
|
||||
signal: abortController.signal
|
||||
}
|
||||
)
|
||||
|
||||
const stream = response.body as unknown as Readable
|
||||
if (!stream) {
|
||||
throw new Error('Failed to get response stream')
|
||||
}
|
||||
|
||||
await processStreamWithRateLimit(stream, onChunk, onEnd, resolve, reject, 640, 20)
|
||||
await processStreamWithRateLimit(stream, onChunk, onEnd, resolve, reject, 640, 20, abortController)
|
||||
break
|
||||
}
|
||||
|
||||
|
|
@ -66,17 +72,21 @@ export const convertTextToSpeechStream = async (
|
|||
apiKey: credentialData.elevenLabsApiKey
|
||||
})
|
||||
|
||||
const response = await client.textToSpeech.stream(textToSpeechConfig.voice || '21m00Tcm4TlvDq8ikWAM', {
|
||||
text: text,
|
||||
modelId: 'eleven_multilingual_v2'
|
||||
})
|
||||
const response = await client.textToSpeech.stream(
|
||||
textToSpeechConfig.voice || '21m00Tcm4TlvDq8ikWAM',
|
||||
{
|
||||
text: text,
|
||||
modelId: 'eleven_multilingual_v2'
|
||||
},
|
||||
{ abortSignal: abortController.signal }
|
||||
)
|
||||
|
||||
const stream = Readable.fromWeb(response as unknown as ReadableStream)
|
||||
if (!stream) {
|
||||
throw new Error('Failed to get response stream')
|
||||
}
|
||||
|
||||
await processStreamWithRateLimit(stream, onChunk, onEnd, resolve, reject, 640, 40)
|
||||
await processStreamWithRateLimit(stream, onChunk, onEnd, resolve, reject, 640, 40, abortController)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -99,7 +109,8 @@ const processStreamWithRateLimit = async (
|
|||
resolve: () => void,
|
||||
reject: (error: any) => void,
|
||||
targetChunkSize: number = 640,
|
||||
rateLimitMs: number = 20
|
||||
rateLimitMs: number = 20,
|
||||
abortController: AbortController
|
||||
) => {
|
||||
const TARGET_CHUNK_SIZE = targetChunkSize
|
||||
const RATE_LIMIT_MS = rateLimitMs
|
||||
|
|
@ -109,6 +120,13 @@ const processStreamWithRateLimit = async (
|
|||
|
||||
const processChunks = async () => {
|
||||
while (!isEnded || buffer.length > 0) {
|
||||
// Check if aborted
|
||||
if (abortController.signal.aborted) {
|
||||
stream.destroy()
|
||||
reject(new Error('TTS generation aborted'))
|
||||
return
|
||||
}
|
||||
|
||||
if (buffer.length >= TARGET_CHUNK_SIZE) {
|
||||
const chunk = buffer.subarray(0, TARGET_CHUNK_SIZE)
|
||||
buffer = buffer.subarray(TARGET_CHUNK_SIZE)
|
||||
|
|
@ -129,7 +147,9 @@ const processStreamWithRateLimit = async (
|
|||
}
|
||||
|
||||
stream.on('data', (chunk) => {
|
||||
buffer = Buffer.concat([buffer, Buffer.from(chunk)])
|
||||
if (!abortController.signal.aborted) {
|
||||
buffer = Buffer.concat([buffer, Buffer.from(chunk)])
|
||||
}
|
||||
})
|
||||
|
||||
stream.on('end', () => {
|
||||
|
|
@ -140,6 +160,12 @@ const processStreamWithRateLimit = async (
|
|||
reject(error)
|
||||
})
|
||||
|
||||
// Handle abort signal
|
||||
abortController.signal.addEventListener('abort', () => {
|
||||
stream.destroy()
|
||||
reject(new Error('TTS generation aborted'))
|
||||
})
|
||||
|
||||
processChunks().catch(reject)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -92,37 +92,51 @@ const generateTextToSpeech = async (req: Request, res: Response) => {
|
|||
model: model
|
||||
}
|
||||
|
||||
await convertTextToSpeechStream(
|
||||
text,
|
||||
textToSpeechConfig,
|
||||
options,
|
||||
(format: string) => {
|
||||
const startResponse = {
|
||||
event: 'tts_start',
|
||||
data: { chatMessageId, format }
|
||||
// Create and store AbortController
|
||||
const abortController = new AbortController()
|
||||
const ttsAbortId = `tts_${chatId}_${chatMessageId}`
|
||||
appServer.abortControllerPool.add(ttsAbortId, abortController)
|
||||
|
||||
try {
|
||||
await convertTextToSpeechStream(
|
||||
text,
|
||||
textToSpeechConfig,
|
||||
options,
|
||||
abortController,
|
||||
(format: string) => {
|
||||
const startResponse = {
|
||||
event: 'tts_start',
|
||||
data: { chatMessageId, format }
|
||||
}
|
||||
res.write('event: tts_start\n')
|
||||
res.write(`data: ${JSON.stringify(startResponse)}\n\n`)
|
||||
},
|
||||
(chunk: Buffer) => {
|
||||
const audioBase64 = chunk.toString('base64')
|
||||
const clientResponse = {
|
||||
event: 'tts_data',
|
||||
data: { chatMessageId, audioChunk: audioBase64 }
|
||||
}
|
||||
res.write('event: tts_data\n')
|
||||
res.write(`data: ${JSON.stringify(clientResponse)}\n\n`)
|
||||
},
|
||||
async () => {
|
||||
const endResponse = {
|
||||
event: 'tts_end',
|
||||
data: { chatMessageId }
|
||||
}
|
||||
res.write('event: tts_end\n')
|
||||
res.write(`data: ${JSON.stringify(endResponse)}\n\n`)
|
||||
res.end()
|
||||
// Clean up from pool on successful completion
|
||||
appServer.abortControllerPool.remove(ttsAbortId)
|
||||
}
|
||||
res.write('event: tts_start\n')
|
||||
res.write(`data: ${JSON.stringify(startResponse)}\n\n`)
|
||||
},
|
||||
(chunk: Buffer) => {
|
||||
const audioBase64 = chunk.toString('base64')
|
||||
const clientResponse = {
|
||||
event: 'tts_data',
|
||||
data: { chatMessageId, audioChunk: audioBase64 }
|
||||
}
|
||||
res.write('event: tts_data\n')
|
||||
res.write(`data: ${JSON.stringify(clientResponse)}\n\n`)
|
||||
},
|
||||
async () => {
|
||||
const endResponse = {
|
||||
event: 'tts_end',
|
||||
data: { chatMessageId }
|
||||
}
|
||||
res.write('event: tts_end\n')
|
||||
res.write(`data: ${JSON.stringify(endResponse)}\n\n`)
|
||||
res.end()
|
||||
}
|
||||
)
|
||||
)
|
||||
} catch (error) {
|
||||
// Clean up from pool on error
|
||||
appServer.abortControllerPool.remove(ttsAbortId)
|
||||
throw error
|
||||
}
|
||||
} catch (error) {
|
||||
if (!res.headersSent) {
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
|
|
@ -160,6 +174,11 @@ const abortTextToSpeech = async (req: Request, res: Response) => {
|
|||
|
||||
const appServer = getRunningExpressApp()
|
||||
|
||||
// Abort the TTS generation using existing pool
|
||||
const ttsAbortId = `tts_${chatId}_${chatMessageId}`
|
||||
appServer.abortControllerPool.abort(ttsAbortId)
|
||||
|
||||
// Send abort event to client
|
||||
appServer.sseStreamer.streamTTSAbortEvent(chatId, chatMessageId)
|
||||
|
||||
res.json({ message: 'TTS stream aborted successfully', chatId, chatMessageId })
|
||||
|
|
|
|||
|
|
@ -2189,7 +2189,15 @@ export const executeAgentFlow = async ({
|
|||
}
|
||||
|
||||
if (sseStreamer) {
|
||||
await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, chatMessage?.id, sseStreamer)
|
||||
await generateTTSForResponseStream(
|
||||
result.text,
|
||||
chatflow.textToSpeech,
|
||||
options,
|
||||
chatId,
|
||||
chatMessage?.id,
|
||||
sseStreamer,
|
||||
abortController
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -95,7 +95,8 @@ const generateTTSForResponseStream = async (
|
|||
options: ICommonObject,
|
||||
chatId: string,
|
||||
chatMessageId: string,
|
||||
sseStreamer: IServerSideEventStreamer
|
||||
sseStreamer: IServerSideEventStreamer,
|
||||
abortController?: AbortController
|
||||
): Promise<void> => {
|
||||
try {
|
||||
if (!textToSpeechConfig) return
|
||||
|
|
@ -121,6 +122,7 @@ const generateTTSForResponseStream = async (
|
|||
responseText,
|
||||
activeProviderConfig,
|
||||
options,
|
||||
abortController || new AbortController(),
|
||||
(format: string) => {
|
||||
sseStreamer.streamTTSStartEvent(chatId, chatMessageId, format)
|
||||
},
|
||||
|
|
@ -908,9 +910,25 @@ export const executeFlow = async ({
|
|||
}
|
||||
|
||||
if (streaming && sseStreamer) {
|
||||
await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, chatMessage?.id, sseStreamer)
|
||||
await generateTTSForResponseStream(
|
||||
result.text,
|
||||
chatflow.textToSpeech,
|
||||
options,
|
||||
chatId,
|
||||
chatMessage?.id,
|
||||
sseStreamer,
|
||||
signal
|
||||
)
|
||||
} else if (sseStreamer) {
|
||||
await generateTTSForResponseStream(result.text, chatflow.textToSpeech, options, chatId, chatMessage?.id, sseStreamer)
|
||||
await generateTTSForResponseStream(
|
||||
result.text,
|
||||
chatflow.textToSpeech,
|
||||
options,
|
||||
chatId,
|
||||
chatMessage?.id,
|
||||
sseStreamer,
|
||||
signal
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue