import gradio as gr
import torch
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor, TextStreamer
from qwen_omni_utils import process_mm_info
import soundfile as sf
import tempfile
import spaces
import gc
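
# NOTE: Qwen2_5OmniModel/Qwen2_5OmniProcessor ship with the Qwen2.5-Omni preview
# build of transformers; recent transformers releases name the model class
# Qwen2_5OmniForConditionalGeneration instead. process_mm_info comes from the
# qwen-omni-utils pip package.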
# Initialize the model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on GPU; plain float32 on CPU (float16 CPU inference is poorly supported)
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
def get_model():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
    model = Qwen2_5OmniModel.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B",
        torch_dtype=torch_dtype,
        device_map="auto",
        enable_audio_output=True,
        low_cpu_mem_usage=True,
        # flash_attention_2 requires the flash-attn package; None keeps the default
        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
    )
    return model
model = get_model()
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
# System prompt (the model card requires this exact wording for speech output)
SYSTEM_PROMPT = {
    "role": "system",
    "content": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.",
}

# Voice options supported by Qwen2.5-Omni's talker
VOICE_OPTIONS = {
    "Chelsie (Female)": "Chelsie",
    "Ethan (Male)": "Ethan",
}
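
# For reference, process_mm_info expects user content as a list of typed
# segments, e.g. (the file name is illustrative):
#   {"role": "user", "content": [
#       {"type": "video", "video": "clip.mp4"},
#       {"type": "text", "text": "Describe this video."},
#   ]}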
@spaces.GPU  # request a ZeroGPU slot for the duration of this call (Hugging Face Spaces)
def process_input(video, text, voice_type, enable_audio_output):
    try:
        # Clear GPU memory before processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
        # Build the user turn as typed segments rather than a flat dict
        content = []
        if video is not None:
            content.append({"type": "video", "video": video})
        if text:
            content.append({"type": "text", "text": text})
        # Conversation history for model processing
        conversation = [SYSTEM_PROMPT, {"role": "user", "content": content}]
        # Split the conversation into audio, image, and video inputs
        try:
            audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
        except Exception as e:
            print(f"Error processing multimedia: {str(e)}")
            audios, images, videos = None, None, None  # the processor treats None as "modality absent"
        inputs = processor(
            text=processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False),
            audio=audios,
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True,
        )
        # Move tensors to the model's device; cast only floating-point tensors
        # so integer token ids keep their dtype
        inputs = {
            k: (v.to(model.device, model.dtype) if v.is_floating_point() else v.to(model.device))
            if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }
        # Generate the response, optionally with speech output
        text_ids = None
        audio_path = None
        if enable_audio_output:
            voice_type_value = VOICE_OPTIONS.get(voice_type, "Chelsie")
            try:
                generation_output = model.generate(
                    **inputs,
                    use_audio_in_video=False,
                    return_audio=True,
                    spk=voice_type_value,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    streamer=TextStreamer(processor.tokenizer, skip_prompt=True),
                )
                # With return_audio=True, generate() returns (token ids, waveform)
                if isinstance(generation_output, tuple) and len(generation_output) == 2:
                    text_ids, audio = generation_output
                    if audio is not None:
                        # The talker emits 24 kHz mono audio
                        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                            sf.write(tmp_file.name, audio.reshape(-1).detach().cpu().numpy(), samplerate=24000)
                            audio_path = tmp_file.name
            except Exception as e:
                print(f"Error during audio generation: {str(e)}")
                # Fall through to text-only generation below
        # Text-only generation if audio was disabled or failed
        if text_ids is None:
            try:
                text_ids = model.generate(
                    **inputs,
                    use_audio_in_video=False,
                    return_audio=False,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    streamer=TextStreamer(processor.tokenizer, skip_prompt=True),
                )
            except Exception as e:
                print(f"Error during fallback text generation: {str(e)}")
        # Decode the text response
        if text_ids is not None:
            text_response = processor.batch_decode(text_ids, skip_special_tokens=True)[0]
        else:
            text_response = "Error generating response."
        return text_response.strip(), audio_path
    except Exception as e:
        print(f"Error in process_input: {str(e)}")
        return "Error processing input.", None
# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("## Qwen2.5-Omni-7B Multimodal Demo")
    with gr.Row():
        video_input = gr.Video(label="Upload Video (max 120s)", sources=["upload"], max_length=120)
        prompt_input = gr.Textbox(label="Analysis Prompt", placeholder="Describe or ask about the video...")
    voice_selection = gr.Dropdown(label="Voice Type", choices=list(VOICE_OPTIONS.keys()), value="Chelsie (Female)")
    enable_audio_checkbox = gr.Checkbox(label="Enable Audio Output", value=True)
    submit_btn = gr.Button("Analyze", variant="primary")
    with gr.Column():
        text_output = gr.Textbox(label="Analysis Results", interactive=False)
        audio_output = gr.Audio(label="Speech Response", autoplay=True)
    submit_btn.click(
        process_input,
        inputs=[video_input, prompt_input, voice_selection, enable_audio_checkbox],
        outputs=[text_output, audio_output],
    )
# Gradio 4.x renamed queue(concurrency_count=...) to default_concurrency_limit
demo.queue(default_concurrency_limit=2)
demo.launch(server_name="0.0.0.0", server_port=7860)
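
# A minimal requirements.txt sketch for this Space (pin versions as needed):
#   gradio
#   torch
#   transformers  (a build that provides Qwen2_5OmniModel)
#   qwen-omni-utils
#   soundfile
#   spaces
#   flash-attn    (optional, CUDA only)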