# Last commit by Sergidev: "v4" (f037fe5)
import gradio as gr
import torch
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor, TextStreamer
from qwen_omni_utils import process_mm_info
import soundfile as sf
import tempfile
import spaces
import gc
# Initialize the model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on GPU; float32 on CPU, where half-precision kernel support is
# limited or slow in many PyTorch builds.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
def get_model(model_id="Qwen/Qwen2.5-Omni-7B"):
    """Load the Qwen2.5-Omni model used by this demo.

    Args:
        model_id: Hub repo id or local path of the checkpoint. Defaults to the
            7B checkpoint this demo was built for.

    Returns:
        The ``Qwen2_5OmniModel`` loaded with ``device_map="auto"`` in the
        module-level ``torch_dtype``.
    """
    import importlib.util

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Drop cached allocator blocks before allocating the weights.
        torch.cuda.empty_cache()
    gc.collect()

    # FlashAttention-2 needs CUDA *and* the optional `flash_attn` package.
    # Without either, use the library's default attention (None) instead of
    # failing at load time.
    attn_implementation = None
    if use_cuda and importlib.util.find_spec("flash_attn") is not None:
        attn_implementation = "flash_attention_2"

    return Qwen2_5OmniModel.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        device_map="auto",
        enable_audio_output=True,
        low_cpu_mem_usage=True,
        attn_implementation=attn_implementation,
    )
# Load the model weights first, then the matching processor; both are created
# once at import time and shared by every request.
model, processor = (
    get_model(),
    Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B"),
)
# System prompt sent as the first message of every conversation.
# NOTE(review): the Qwen2.5-Omni model card says speech output may not work
# as expected unless this exact wording is used — keep it verbatim.
SYSTEM_PROMPT = dict(
    role="system",
    content="You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.",
)

# Voice options: dropdown label -> speaker name handed to generation.
VOICE_OPTIONS = dict(
    [
        ("Chelsie (Female)", "Chelsie"),
        ("Ethan (Male)", "Ethan"),
    ]
)
def _build_conversation(video, text):
    """Return ``[system, user]`` messages for a single request.

    User content is a list of typed items (``{"type": "video", ...}`` and
    ``{"type": "text", ...}``), the shape Qwen2.5-Omni's usage examples pass
    to ``apply_chat_template`` and ``process_mm_info``. An item is added only
    for an input that is actually present.
    """
    user_content = []
    if video is not None:
        user_content.append({"type": "video", "video": video})
    if text:
        user_content.append({"type": "text", "text": text})
    return [SYSTEM_PROMPT, {"role": "user", "content": user_content}]


def _move_to_model(batch, device, dtype):
    """Move tensors in *batch* to *device*, casting only floating-point ones to *dtype*.

    Non-tensor values pass through unchanged. Integer tensors (e.g. token ids,
    attention masks) keep their dtype: casting token ids to a float dtype such
    as bfloat16 loses precision and breaks the embedding lookup.
    """
    moved = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            if value.is_floating_point():
                value = value.to(device=device, dtype=dtype)
            else:
                value = value.to(device=device)
        moved[key] = value
    return moved


def _generate(inputs, *, return_audio, spk=None):
    """Call ``model.generate`` with this demo's shared sampling settings.

    ``spk`` is forwarded only when given (speech generation).
    """
    speaker_kwargs = {"spk": spk} if spk is not None else {}
    return model.generate(
        **inputs,
        use_audio_in_video=False,
        return_audio=return_audio,
        # NOTE(review): Qwen2.5-Omni's generate pre-sets its own thinker token
        # limit, so a plain `max_new_tokens` does not reach the text model;
        # the `thinker_` form targets it explicitly. Confirm against the
        # installed transformers version.
        thinker_max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        streamer=TextStreamer(processor, skip_prompt=True),
        **speaker_kwargs,
    )


def _save_wav(audio, samplerate=24000):
    """Write a waveform tensor to a new temporary ``.wav`` file and return its path."""
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
        path = tmp_file.name
    # Write after the handle is closed so the path can be reopened on any OS;
    # `.float()` guards against dtypes NumPy cannot represent (e.g. bfloat16).
    sf.write(path, audio.reshape(-1).detach().float().cpu().numpy(), samplerate=samplerate)
    return path


@spaces.GPU(duration=120)
def process_input(video, text, voice_type, enable_audio_output):
    """Answer a prompt about an optional video, optionally with a spoken reply.

    Args:
        video: Path of the uploaded video, or ``None``.
        text: User prompt; may be empty when a video is supplied.
        voice_type: A key of ``VOICE_OPTIONS``; unknown keys use "Chelsie".
        enable_audio_output: When true, also synthesize speech.

    Returns:
        ``(text_response, audio_path)``. ``audio_path`` is a temporary ``.wav``
        file or ``None``. Failures are printed and reported as a short error
        string instead of raising, so the UI always receives two outputs.
    """
    try:
        # Clear GPU memory before processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        if video is None and not text:
            return "Please provide a prompt or a video.", None

        conversation = _build_conversation(video, text)
        try:
            _, _, videos = process_mm_info(conversation, use_audio_in_video=False)
        except Exception as e:
            print(f"Error processing multimedia: {str(e)}")
            # Without decoded frames, a video placeholder in the prompt would
            # have no matching features, so continue with the text alone.
            conversation = _build_conversation(None, text)
            videos = None

        inputs = processor(
            text=processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False),
            videos=videos,
            return_tensors="pt",
            padding=True,
        )
        inputs = _move_to_model(inputs, model.device, model.dtype)

        text_ids = None
        audio_path = None
        if enable_audio_output:
            try:
                generation_output = _generate(
                    inputs,
                    return_audio=True,
                    spk=VOICE_OPTIONS.get(voice_type, "Chelsie"),
                )
                if isinstance(generation_output, tuple) and len(generation_output) == 2:
                    text_ids, audio = generation_output
                    if audio is not None:
                        audio_path = _save_wav(audio)
            except Exception as e:
                print(f"Error during audio generation: {str(e)}")

        # Text-only generation: requested directly, or the fallback when the
        # audio path failed or returned an unexpected shape.
        if text_ids is None:
            try:
                text_ids = _generate(inputs, return_audio=False)
            except Exception as e:
                print(f"Error during fallback text generation: {str(e)}")

        if text_ids is None:
            return "Error generating response.", audio_path

        # Decoder-only generation returns the prompt tokens followed by the new
        # ones; decode only the new tokens so the reply doesn't echo the prompt.
        prompt_length = inputs["input_ids"].shape[1]
        text_response = processor.batch_decode(text_ids[:, prompt_length:], skip_special_tokens=True)[0]
        return text_response.strip(), audio_path
    except Exception as e:
        print(f"Error in process_input: {str(e)}")
        return "Error processing input.", None
# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("## Qwen2.5-Omni-7B Multimodal Demo")
    with gr.Row():
        # Left column: inputs and generation options.
        with gr.Column():
            video_input = gr.Video(label="Upload Video (max 120s)", sources=["upload"], max_length=120)
            prompt_input = gr.Textbox(label="Analysis Prompt", placeholder="Describe or ask about the video...")
            voice_selection = gr.Dropdown(label="Voice Type", choices=list(VOICE_OPTIONS.keys()), value="Chelsie (Female)")
            enable_audio_checkbox = gr.Checkbox(label="Enable Audio Output", value=True)
            submit_btn = gr.Button("Analyze", variant="primary")
        # Right column: model outputs.
        with gr.Column():
            text_output = gr.Textbox(label="Analysis Results", interactive=False)
            audio_output = gr.Audio(label="Speech Response", autoplay=True)
    submit_btn.click(
        process_input,
        inputs=[video_input, prompt_input, voice_selection, enable_audio_checkbox],
        outputs=[text_output, audio_output],
    )

# Gradio 4 (this file uses its `gr.Video(sources=...)` API) rejects
# `queue(concurrency_count=...)`; `default_concurrency_limit` replaces it.
demo.queue(default_concurrency_limit=2)
demo.launch(server_name="0.0.0.0", server_port=7860)