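# Gradio demo for Qwen/Qwen2.5-Omni-7B: takes a video plus a text prompt and
# returns streamed text and, optionally, synthesized speech.
# Note: this script appears to target the Qwen2.5-Omni preview build of
# transformers (Qwen2_5OmniModel / enable_audio_output); released transformers
# versions expose the model as Qwen2_5OmniForConditionalGeneration instead.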
import gradio as gr
import torch
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor, TextStreamer
from qwen_omni_utils import process_mm_info
import soundfile as sf
import tempfile
import spaces
import gc

# Initialize the model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32  # fp16 kernels are poorly supported on CPU

def get_model():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
    model = Qwen2_5OmniModel.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B",
        torch_dtype=torch_dtype,
        device_map="auto",
        enable_audio_output=True,
        low_cpu_mem_usage=True,
        # flash_attention_2 requires the flash-attn package; None selects the default implementation
        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
    )
    return model

model = get_model()
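# Qwen2_5OmniProcessor bundles the tokenizer with the audio/image/video
# feature extractors, so a single object prepares all modalities.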
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")

# System prompt
SYSTEM_PROMPT = {
    "role": "system",
    "content": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
}
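# Per the Qwen2.5-Omni model card, speech output is only produced when this
# exact system prompt is used.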

# Voice options
VOICE_OPTIONS = {
    "Chelsie (Female)": "Chelsie",
    "Ethan (Male)": "Ethan"
}
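# Chelsie and Ethan are the two speakers shipped with the model; Chelsie is the default.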

# On ZeroGPU Spaces, @spaces.GPU reserves a GPU for up to `duration` seconds per call
@spaces.GPU(duration=120)
def process_input(video, text, voice_type, enable_audio_output):
    try:
        # Clear GPU memory before processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

        # Build the user turn as the list of typed content items that
        # process_mm_info and the chat template expect
        user_content = []
        if video is not None:
            user_content.append({"type": "video", "video": video})
        if text:
            user_content.append({"type": "text", "text": text})

        conversation = [
            SYSTEM_PROMPT,
            {"role": "user", "content": user_content},
        ]

        # Process multimedia information
        try:
            audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
        except Exception as e:
            print(f"Error processing multimedia: {str(e)}")
            audios, images, videos = None, None, None

        inputs = processor(
            text=processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False),
            audio=audios,
            images=images,
            videos=videos,
            use_audio_in_video=False,
            return_tensors="pt",
            padding=True
        )

        # Move tensors to the model device; cast only floating-point tensors,
        # since input_ids and attention_mask must keep their integer dtypes
        inputs = {
            k: (v.to(device=model.device, dtype=model.dtype) if v.is_floating_point() else v.to(model.device))
            if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

        # Generate response with streaming and audio output
        text_ids = None
        audio_path = None

        if enable_audio_output:
            voice_type_value = VOICE_OPTIONS.get(voice_type, "Chelsie")
            try:
                generation_output = model.generate(
                    **inputs,
                    use_audio_in_video=False,
                    return_audio=True,
                    spk=voice_type_value,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    streamer=TextStreamer(processor.tokenizer, skip_prompt=True)
                )
                if isinstance(generation_output, tuple) and len(generation_output) == 2:
                    text_ids, audio = generation_output
                    if audio is not None:
                        # The talker head emits 24 kHz mono audio
                        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                            sf.write(tmp_file.name, audio.reshape(-1).detach().cpu().numpy(), samplerate=24000)
                            audio_path = tmp_file.name
            except Exception as e:
                print(f"Error during audio generation: {str(e)}")

        # Fall back to text-only generation if audio fails
        if text_ids is None:
            try:
                text_ids = model.generate(
                    **inputs,
                    use_audio_in_video=False,
                    return_audio=False,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    streamer=TextStreamer(processor.tokenizer, skip_prompt=True)
                )
            except Exception as e:
                print(f"Error during fallback text generation: {str(e)}")

        # Decode only the newly generated tokens; generate returns prompt + completion
        if text_ids is not None:
            new_token_ids = text_ids[:, inputs["input_ids"].shape[1]:]
            text_response = processor.batch_decode(new_token_ids, skip_special_tokens=True)[0]
        else:
            text_response = "Error generating response."

        return text_response.strip(), audio_path

    except Exception as e:
        print(f"Error in process_input: {str(e)}")
        return "Error processing input.", None
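
# Minimal sanity check (bypasses Gradio; "sample.mp4" is a hypothetical local file):
#   text, wav = process_input("sample.mp4", "What happens in this clip?", "Chelsie (Female)", True)
#   print(text, wav)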

# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("## Qwen2.5-Omni-7B Multimodal Demo")

    with gr.Row():
        video_input = gr.Video(label="Upload Video (max 120s)", sources=["upload"], max_length=120)
        prompt_input = gr.Textbox(label="Analysis Prompt", placeholder="Describe or ask about the video...")

    voice_selection = gr.Dropdown(label="Voice Type", choices=list(VOICE_OPTIONS.keys()), value="Chelsie (Female)")
    enable_audio_checkbox = gr.Checkbox(label="Enable Audio Output", value=True)

    submit_btn = gr.Button("Analyze", variant="primary")

    with gr.Column():
        text_output = gr.Textbox(label="Analysis Results", interactive=False)
        audio_output = gr.Audio(label="Speech Response", autoplay=True)

    submit_btn.click(
        process_input,
        inputs=[video_input, prompt_input, voice_selection, enable_audio_checkbox],
        outputs=[text_output, audio_output]
    )

demo.queue(default_concurrency_limit=2)  # Gradio 4.x renamed concurrency_count
demo.launch(server_name="0.0.0.0", server_port=7860)
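
# Run locally with `python app.py`; requires gradio, spaces, torch, soundfile,
# qwen-omni-utils, and a transformers build with Qwen2.5-Omni support.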