import gradio as gr
import torch
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor, TextStreamer
from qwen_omni_utils import process_mm_info
import soundfile as sf
import tempfile
import spaces
import gc
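# Dependency notes (assumptions, not pinned versions):
# - qwen_omni_utils comes from the Qwen2.5-Omni release (pip install qwen-omni-utils)
#   and provides process_mm_info for pulling audio/image/video out of a conversation.
# - Qwen2_5OmniModel / Qwen2_5OmniProcessor are the class names from the
#   Qwen-provided transformers build; the upstream transformers release exposes
#   the generation class as Qwen2_5OmniForConditionalGeneration instead.
# - attn_implementation="flash_attention_2" (used below) additionally requires
#   the flash-attn package to be installed.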
# Initialize the model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on GPU; plain float32 on CPU, where half-precision inference is poorly supported
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
def get_model():
    # Release any cached GPU memory before loading the 7B checkpoint
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
    model = Qwen2_5OmniModel.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B",
        torch_dtype=torch_dtype,
        device_map="auto",
        enable_audio_output=True,
        low_cpu_mem_usage=True,
        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
    )
    return model
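# Rough sizing note: 7B parameters in bfloat16 is ~14 GB of weights alone
# (7e9 params x 2 bytes), before activations, KV cache, and the audio talker.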
model = get_model()
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
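# The processor bundles the tokenizer with the audio/image/video feature
# extractors; apply_chat_template and batch_decode below delegate to the tokenizer.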
# System prompt
SYSTEM_PROMPT = {
    "role": "system",
    "content": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
}
# Voice options
VOICE_OPTIONS = {
    "Chelsie (Female)": "Chelsie",
    "Ethan (Male)": "Ethan"
}
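# Chelsie and Ethan are the two speaker voices the Qwen2.5-Omni talker supports,
# per the model card; Chelsie is the model's default.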
@spaces.GPU(duration=120)
def process_input(video, text, voice_type, enable_audio_output):
    try:
        # Clear GPU memory before processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
        # Build the user turn in the typed-content format process_mm_info expects:
        # a list of {"type": ..., ...} entries rather than a flat dict
        user_content = []
        if video is not None:
            user_content.append({"type": "video", "video": video})
        if text:
            user_content.append({"type": "text", "text": text})
        # Prepare conversation history for model processing
        conversation = [SYSTEM_PROMPT, {"role": "user", "content": user_content}]
        # Extract the audio, image, and video streams referenced by the conversation
        try:
            audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
        except Exception as e:
            print(f"Error processing multimedia: {str(e)}")
            audios, images, videos = None, None, None
        inputs = processor(
            text=processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False),
            audio=audios,
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True,
        )
        # Move tensors to the model's device; cast only floating-point tensors to
        # the model dtype (input_ids and attention_mask must stay integer)
        inputs = {
            k: (v.to(model.device, dtype=model.dtype) if v.is_floating_point() else v.to(model.device))
            if isinstance(v, torch.Tensor)
            else v
            for k, v in inputs.items()
        }
        # Generate the response, streaming text and optionally synthesizing speech
        text_ids = None
        audio_path = None
        if enable_audio_output:
            voice_type_value = VOICE_OPTIONS.get(voice_type, "Chelsie")
            try:
                generation_output = model.generate(
                    **inputs,
                    use_audio_in_video=False,
                    return_audio=True,
                    spk=voice_type_value,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    streamer=TextStreamer(processor, skip_prompt=True)
                )
                # With return_audio=True, generate returns (text_ids, audio_waveform)
                if isinstance(generation_output, tuple) and len(generation_output) == 2:
                    text_ids, audio = generation_output
                    if audio is not None:
                        # Persist the waveform as a 24 kHz WAV file for the Gradio player
                        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                            sf.write(tmp_file.name, audio.reshape(-1).detach().cpu().numpy(), samplerate=24000)
                            audio_path = tmp_file.name
            except Exception as e:
                print(f"Error during audio generation: {str(e)}")
        # Fall back to text-only generation if audio was disabled or failed
        if text_ids is None:
            try:
                text_ids = model.generate(
                    **inputs,
                    use_audio_in_video=False,
                    return_audio=False,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    streamer=TextStreamer(processor, skip_prompt=True)
                )
            except Exception as e:
                print(f"Error during fallback text generation: {str(e)}")
        # Decode the generated token ids into text
        text_response = (
            processor.batch_decode(text_ids, skip_special_tokens=True)[0]
            if text_ids is not None
            else "Error generating response."
        )
        return text_response.strip(), audio_path
    except Exception as e:
        print(f"Error in process_input: {str(e)}")
        return "Error processing input.", None
# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("## Qwen2.5-Omni-7B Multimodal Demo")
    with gr.Row():
        video_input = gr.Video(label="Upload Video (max 120s)", sources=["upload"], max_length=120)
        prompt_input = gr.Textbox(label="Analysis Prompt", placeholder="Describe or ask about the video...")
    voice_selection = gr.Dropdown(label="Voice Type", choices=list(VOICE_OPTIONS.keys()), value="Chelsie (Female)")
    enable_audio_checkbox = gr.Checkbox(label="Enable Audio Output", value=True)
    submit_btn = gr.Button("Analyze", variant="primary")
    with gr.Column():
        text_output = gr.Textbox(label="Analysis Results", interactive=False)
        audio_output = gr.Audio(label="Speech Response", autoplay=True)
    submit_btn.click(
        process_input,
        inputs=[video_input, prompt_input, voice_selection, enable_audio_checkbox],
        outputs=[text_output, audio_output]
    )

# Gradio 4.x replaced queue(concurrency_count=...) with default_concurrency_limit
demo.queue(default_concurrency_limit=2)
demo.launch(server_name="0.0.0.0", server_port=7860)
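# To run locally: python app.py, then open http://localhost:7860 in a browser.
# On Hugging Face Spaces, the @spaces.GPU decorator requests ZeroGPU hardware
# per call, for at most the declared 120-second duration.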