import gradio as gr
import torch
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor, TextStreamer
from qwen_omni_utils import process_mm_info
import soundfile as sf
import tempfile
import spaces
import gc

# Initialize the model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on GPU; fall back to float32 on CPU (float16 is poorly supported for CPU inference)
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32


def get_model():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
    model = Qwen2_5OmniModel.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B",
        torch_dtype=torch_dtype,
        device_map="auto",
        enable_audio_output=True,
        low_cpu_mem_usage=True,
        # flash_attention_2 requires the flash-attn package; fall back to the default otherwise
        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
    )
    return model


model = get_model()
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")

# System prompt
SYSTEM_PROMPT = {
    "role": "system",
    "content": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.",
}

# Voice options
VOICE_OPTIONS = {
    "Chelsie (Female)": "Chelsie",
    "Ethan (Male)": "Ethan",
}


@spaces.GPU(duration=120)
def process_input(video, text, voice_type, enable_audio_output):
    try:
        # Clear GPU memory before processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

        # Build the user turn in the qwen_omni_utils content format (list of typed items)
        user_content = [{"type": "text", "text": text}]
        if video is not None:
            user_content.append({"type": "video", "video": video})

        # Conversation history for model processing
        conversation = [SYSTEM_PROMPT, {"role": "user", "content": user_content}]

        # Extract multimedia inputs referenced in the conversation
        try:
            audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
        except Exception as e:
            print(f"Error processing multimedia: {str(e)}")
            audios, images, videos = [], [], []

        inputs = processor(
            text=processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False),
            videos=videos if videos else None,
            return_tensors="pt",
            padding=True,
        )

        # Move tensors to the model device; cast only floating-point tensors to the model dtype
        # (input_ids and attention_mask must stay integer)
        inputs = {
            k: (v.to(device=model.device, dtype=model.dtype)
                if isinstance(v, torch.Tensor) and v.dtype.is_floating_point
                else v.to(model.device) if isinstance(v, torch.Tensor) else v)
            for k, v in inputs.items()
        }

        # Generate response with streaming and optional audio output
        text_ids = None
        audio_path = None

        if enable_audio_output:
            voice_type_value = VOICE_OPTIONS.get(voice_type, "Chelsie")
            try:
                generation_output = model.generate(
                    **inputs,
                    use_audio_in_video=False,
                    return_audio=True,
                    spk=voice_type_value,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    streamer=TextStreamer(processor.tokenizer, skip_prompt=True),
                )
                if isinstance(generation_output, tuple) and len(generation_output) == 2:
                    text_ids, audio = generation_output
                    if audio is not None:
                        # Persist the 24 kHz waveform so Gradio can serve it
                        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                            sf.write(
                                tmp_file.name,
                                audio.reshape(-1).detach().cpu().numpy(),
                                samplerate=24000,
                            )
                            audio_path = tmp_file.name
            except Exception as e:
                print(f"Error during audio generation: {str(e)}")

        # Fall back to text-only generation if audio output is disabled or failed
        if text_ids is None:
            try:
                text_ids = model.generate(
                    **inputs,
                    use_audio_in_video=False,
                    return_audio=False,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    streamer=TextStreamer(processor.tokenizer, skip_prompt=True),
                )
            except Exception as e:
                print(f"Error during fallback text generation: {str(e)}")

        # Decode text response
        text_response = (
            processor.batch_decode(text_ids, skip_special_tokens=True)[0]
            if text_ids is not None
            else "Error generating response."
        )
        return text_response.strip(), audio_path

    except Exception as e:
        print(f"Error in process_input: {str(e)}")
        return "Error processing input.", None


# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("## Qwen2.5-Omni-7B Multimodal Demo")
    with gr.Row():
        video_input = gr.Video(label="Upload Video (max 120s)", sources=["upload"], max_length=120)
        prompt_input = gr.Textbox(label="Analysis Prompt", placeholder="Describe or ask about the video...")
    voice_selection = gr.Dropdown(label="Voice Type", choices=list(VOICE_OPTIONS.keys()), value="Chelsie (Female)")
    enable_audio_checkbox = gr.Checkbox(label="Enable Audio Output", value=True)
    submit_btn = gr.Button("Analyze", variant="primary")
    with gr.Column():
        text_output = gr.Textbox(label="Analysis Results", interactive=False)
        audio_output = gr.Audio(label="Speech Response", autoplay=True)

    submit_btn.click(
        process_input,
        inputs=[video_input, prompt_input, voice_selection, enable_audio_checkbox],
        outputs=[text_output, audio_output],
    )

# Gradio 4.x removed concurrency_count from queue(); use default_concurrency_limit instead
demo.queue(default_concurrency_limit=2)
demo.launch(server_name="0.0.0.0", server_port=7860)
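# Usage sketch (assumed filename app.py; package names below are the assumed pip names):
#   pip install gradio torch transformers accelerate soundfile qwen-omni-utils spaces
#   python app.py
# flash-attn is optional and only needed for attn_implementation="flash_attention_2".
# The demo is then served at http://0.0.0.0:7860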