diff --git a/packages/components/nodes/chains/OpenAIMultiModalChain/OpenAIMultiModalChain.ts b/packages/components/nodes/chains/OpenAIMultiModalChain/OpenAIMultiModalChain.ts
index f62d58bc0..a3f7e8158 100644
--- a/packages/components/nodes/chains/OpenAIMultiModalChain/OpenAIMultiModalChain.ts
+++ b/packages/components/nodes/chains/OpenAIMultiModalChain/OpenAIMultiModalChain.ts
@@ -1,15 +1,9 @@
-import {
-    ICommonObject,
-    INode,
-    INodeData,
-    INodeOutputsValue,
-    INodeParams
-} from "../../../src/Interface";
+import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/Interface'
 import { getBaseClasses, getCredentialData, getCredentialParam, handleEscapeCharacters } from '../../../src/utils'
-import { OpenAIMultiModalChainInput, VLLMChain } from "./VLLMChain";
+import { OpenAIMultiModalChainInput, VLLMChain } from './VLLMChain'
 import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
 import { formatResponse } from '../../outputparsers/OutputParserHelpers'
-import { checkInputs, Moderation, streamResponse } from "../../moderation/Moderation";
+import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
 
 class OpenAIMultiModalChain_Chains implements INode {
     label: string
@@ -72,7 +66,7 @@ class OpenAIMultiModalChain_Chains implements INode {
                 label: 'Speech to Text',
                 name: 'speechToText',
                 type: 'boolean',
-                optional: true,
+                optional: true
             },
             // TODO: only show when speechToText is true
             {
@@ -84,7 +78,8 @@ class OpenAIMultiModalChain_Chains implements INode {
                     {
                         label: 'Transcriptions',
                         name: 'transcriptions',
-                        description: 'Transcribe audio into whatever language the audio is in. Default method when Speech to Text is turned on.'
+                        description:
+                            'Transcribe audio into whatever language the audio is in. Default method when Speech to Text is turned on.'
                     },
                     {
                         label: 'Translations',
@@ -186,7 +181,6 @@ class OpenAIMultiModalChain_Chains implements INode {
         const topP = nodeData.inputs?.topP as string
         const speechToText = nodeData.inputs?.speechToText as boolean
 
-
         const fields: OpenAIMultiModalChainInput = {
             openAIApiKey: openAIApiKey,
             imageResolution: imageResolution,
@@ -256,6 +250,22 @@ const runPrediction = async (
     const socketIO = isStreaming ? options.socketIO : undefined
     const socketIOClientId = isStreaming ? options.socketIOClientId : ''
     const moderations = nodeData.inputs?.inputModeration as Moderation[]
+    const speechToText = nodeData.inputs?.speechToText as boolean
+
+    if (options?.uploads) {
+        if (options.uploads.length === 1 && input.length === 0) {
+            if (speechToText) {
+                //special case, text input is empty, but we have an upload (recorded audio)
+                const convertedText = await chain.processAudioWithWisper(options.uploads[0], undefined)
+                //so we use the upload as input
+                input = convertedText
+            }
+            // do not send the audio file to the model
+        } else {
+            chain.uploads = options.uploads
+        }
+    }
+
     if (moderations && moderations.length > 0) {
         try {
             // Use the output of the moderation chain as input for the LLM chain
@@ -273,9 +283,6 @@ const runPrediction = async (
      * TO: { "value": "hello i am ben\n\n\thow are you?" }
      */
     const promptValues = handleEscapeCharacters(promptValuesRaw, true)
-    if (options?.uploads) {
-        chain.uploads = options.uploads
-    }
 
     if (promptValues && inputVariables.length > 0) {
         let seen: string[] = []
diff --git a/packages/components/nodes/chains/OpenAIMultiModalChain/VLLMChain.ts b/packages/components/nodes/chains/OpenAIMultiModalChain/VLLMChain.ts
index 2cf2ce95c..5fcb62520 100644
--- a/packages/components/nodes/chains/OpenAIMultiModalChain/VLLMChain.ts
+++ b/packages/components/nodes/chains/OpenAIMultiModalChain/VLLMChain.ts
@@ -101,42 +101,20 @@ export class VLLMChain extends BaseChain implements OpenAIMultiModalChainInput {
         })
         if (this.speechToTextMode && this.uploads && this.uploads.length > 0) {
             const audioUploads = this.getAudioUploads(this.uploads)
-            for (const url of audioUploads) {
-                const filePath = path.join(getUserHome(), '.flowise', 'gptvision', url.data, url.name)
-
-                // as the image is stored in the server, read the file and convert it to base64
-                const audio_file = fs.createReadStream(filePath)
-                if (this.speechToTextMode.purpose === 'transcriptions') {
-                    const transcription = await this.client.audio.transcriptions.create({
-                        file: audio_file,
-                        model: 'whisper-1'
-                    })
-                    chatMessages.push({
-                        type: 'text',
-                        text: transcription.text
-                    })
-                } else if (this.speechToTextMode.purpose === 'translations') {
-                    const translation = await this.client.audio.translations.create({
-                        file: audio_file,
-                        model: 'whisper-1'
-                    })
-                    chatMessages.push({
-                        type: 'text',
-                        text: translation.text
-                    })
-                }
+            for (const upload of audioUploads) {
+                await this.processAudioWithWisper(upload, chatMessages)
             }
         }
 
         if (this.uploads && this.uploads.length > 0) {
             const imageUploads = this.getImageUploads(this.uploads)
-            for (const url of imageUploads) {
-                let bf = url.data
-                if (url.type == 'stored-file') {
-                    const filePath = path.join(getUserHome(), '.flowise', 'gptvision', url.data, url.name)
+            for (const upload of imageUploads) {
+                let bf = upload.data
+                if (upload.type == 'stored-file') {
+                    const filePath = path.join(getUserHome(), '.flowise', 'gptvision', upload.data, upload.name)
                     // as the image is stored in the server, read the file and convert it to base64
                     const contents = fs.readFileSync(filePath)
-                    bf = 'data:' + url.mime + ';base64,' + contents.toString('base64')
+                    bf = 'data:' + upload.mime + ';base64,' + contents.toString('base64')
                 }
                 chatMessages.push({
                     type: 'image_url',
@@ -182,6 +160,40 @@ export class VLLMChain extends BaseChain implements OpenAIMultiModalChainInput {
         }
     }
 
+    public async processAudioWithWisper(upload: IFileUpload, chatMessages: ChatCompletionContentPart[] | undefined): Promise<string> {
+        const filePath = path.join(getUserHome(), '.flowise', 'gptvision', upload.data, upload.name)
+
+        // as the image is stored in the server, read the file and convert it to base64
+        const audio_file = fs.createReadStream(filePath)
+        if (this.speechToTextMode === 'transcriptions') {
+            const transcription = await this.client.audio.transcriptions.create({
+                file: audio_file,
+                model: 'whisper-1'
+            })
+            if (chatMessages) {
+                chatMessages.push({
+                    type: 'text',
+                    text: transcription.text
+                })
+            }
+            return transcription.text
+        } else if (this.speechToTextMode === 'translations') {
+            const translation = await this.client.audio.translations.create({
+                file: audio_file,
+                model: 'whisper-1'
+            })
+            if (chatMessages) {
+                chatMessages.push({
+                    type: 'text',
+                    text: translation.text
+                })
+            }
+            return translation.text
+        }
+        //should never get here
+        return ''
+    }
+
     getAudioUploads = (urls: any[]) => {
         return urls.filter((url: any) => url.mime.startsWith('audio/'))
     }
diff --git a/packages/ui/src/views/chatmessage/ChatMessage.js b/packages/ui/src/views/chatmessage/ChatMessage.js
index 0d969c5e2..82b17ded6 100644
--- a/packages/ui/src/views/chatmessage/ChatMessage.js
+++ b/packages/ui/src/views/chatmessage/ChatMessage.js
@@ -304,10 +304,11 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
         setRecordingNotSupported(false)
     }
 
-    const onRecordingStopped = () => {
+    const onRecordingStopped = async () => {
         stopAudioRecording(addRecordingToPreviews)
         setIsRecording(false)
         setRecordingNotSupported(false)
+        handlePromptClick('')
     }
 
     const onSourceDialogClick = (data, title) => {
@@ -366,7 +367,9 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
         if (e) e.preventDefault()
 
         if (!promptStarterInput && userInput.trim() === '') {
-            return
+            if (!(previews.length === 1 && previews[0].type === 'audio')) {
+                return
+            }
        }
 
         let input = userInput
@@ -626,7 +629,7 @@ export const ChatMessage = ({ open, chatflowid, isDialog }) => {
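
Note on the new upload handling (not part of the patch above): when the only upload is a recorded audio clip and the typed input is empty, `runPrediction` now transcribes the audio via `chain.processAudioWithWisper` and uses the resulting text as the prompt instead of forwarding the file to the model; otherwise the uploads are attached to the chain as before. The TypeScript sketch below restates that decision in isolation so it is easier to review. The `IFileUpload` shape is simplified, and `resolveChatInput` and `transcribe` are hypothetical names; `transcribe` stands in for `chain.processAudioWithWisper(upload, undefined)`.

```typescript
// Minimal sketch of the branch added to runPrediction, under the assumptions above.
interface IFileUpload {
    data: string // stored-file id or inline payload
    name: string
    mime: string
    type: string
}

async function resolveChatInput(
    input: string,
    uploads: IFileUpload[] | undefined,
    speechToText: boolean,
    transcribe: (upload: IFileUpload) => Promise<string> // stand-in for chain.processAudioWithWisper
): Promise<{ input: string; forwardedUploads: IFileUpload[] }> {
    let forwardedUploads: IFileUpload[] = []
    if (uploads && uploads.length > 0) {
        if (uploads.length === 1 && input.length === 0) {
            if (speechToText) {
                // special case: no typed text, so the single (recorded audio) upload becomes the prompt
                input = await transcribe(uploads[0])
            }
            // in either case the audio file itself is not forwarded to the model
        } else {
            // normal case: keep the typed text and forward the uploads (e.g. images) to the chain
            forwardedUploads = uploads
        }
    }
    return { input, forwardedUploads }
}
```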