ajsbsd committed
Commit 589237f · verified · 1 Parent(s): 266db90

Qwen3-235B-A22B

Files changed (1): app.py +187 -0
app.py ADDED
@@ -0,0 +1,187 @@
+ import gradio as gr
+ import torch
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     SpeechT5Processor,
+     SpeechT5ForTextToSpeech,
+     SpeechT5HifiGan,
+     WhisperProcessor,
+     WhisperForConditionalGeneration
+ )
+ from datasets import load_dataset
+ import os
+ import spaces
+ import tempfile
+ import soundfile as sf
+ import librosa
+
+ # --- Configuration ---
+ HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
+ TORCH_DTYPE = torch.bfloat16
+ MAX_NEW_TOKENS = 512
+ DO_SAMPLE = True
+ TEMPERATURE = 0.7
+ TOP_K = 50
+ TOP_P = 0.95
+
+ TTS_MODEL_ID = "microsoft/speecht5_tts"
+ TTS_VOCODER_ID = "microsoft/speecht5_hifigan"
+ STT_MODEL_ID = "openai/whisper-small"
+
+ # --- Global Variables ---
+ tokenizer = None
+ llm_model = None
+ tts_processor = None
+ tts_model = None
+ tts_vocoder = None
+ speaker_embeddings = None
+ whisper_processor = None
+ whisper_model = None
+ first_load = True
+
+ # --- Load Models Function ---
+ @spaces.GPU
+ def load_models():
+     global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
+     global whisper_processor, whisper_model
+
+     if (tokenizer is not None and llm_model is not None and tts_model is not None and
+             whisper_model is not None):
+         print("All models already loaded.")
+         return
+
+     hf_token = os.environ.get("HF_TOKEN")
+
+     # LLM
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+         llm_model = AutoModelForCausalLM.from_pretrained(
+             HUGGINGFACE_MODEL_ID,
+             torch_dtype=TORCH_DTYPE,
+             device_map="auto",
+             token=hf_token
+         ).eval()
+         print("LLM loaded successfully.")
+     except Exception as e:
+         print(f"Error loading LLM: {e}")
+
+     # TTS
+     try:
+         tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
+         tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
+         tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)
+         embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
+         speaker_embeddings = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)  # x-vector for a US English speaker, as in the SpeechT5 examples
+         device = llm_model.device if llm_model else 'cpu'
+         tts_model.to(device)
+         tts_vocoder.to(device)
+         speaker_embeddings = speaker_embeddings.to(device)
+         print("TTS models loaded.")
+     except Exception as e:
+         print(f"Error loading TTS: {e}")
+
+     # STT
+     try:
+         whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
+         whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)
+         whisper_model.to(llm_model.device if llm_model else 'cpu')
+         print("Whisper loaded.")
+     except Exception as e:
+         print(f"Error loading Whisper: {e}")
+
+ # --- Generate Response + Audio ---
+ @spaces.GPU
+ def generate_response_and_audio(message, history):
+     global first_load
+     if first_load:
+         load_models()
+         first_load = False
+
+     global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
+
+     if tokenizer is None or llm_model is None:
+         return history + [{"role": "assistant", "content": "Error: LLM not loaded."}], None
+
+     messages = history.copy()
+     messages.append({"role": "user", "content": message})
+
+     try:
+         input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     except Exception:
+         # Fallback prompt format if the tokenizer has no chat template.
+         input_text = ""
+         for item in history:
+             input_text += f"{item['role'].capitalize()}: {item['content']}\n"
+         input_text += f"User: {message}\nAssistant:"
+
+     try:
+         inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(llm_model.device)
+         output_ids = llm_model.generate(
+             inputs["input_ids"],
+             attention_mask=inputs["attention_mask"],
+             max_new_tokens=MAX_NEW_TOKENS,
+             do_sample=DO_SAMPLE,
+             temperature=TEMPERATURE,
+             top_k=TOP_K,
+             top_p=TOP_P,
+             pad_token_id=tokenizer.eos_token_id
+         )
+         generated_text = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True).strip()
+     except Exception as e:
+         print(f"LLM error: {e}")
+         return messages + [{"role": "assistant", "content": "I had an issue generating a response."}], None
+
+     audio_path = None
+     if None not in [tts_processor, tts_model, tts_vocoder, speaker_embeddings]:
+         try:
+             tts_inputs = tts_processor(text=generated_text, return_tensors="pt", max_length=550, truncation=True).to(llm_model.device)
+             speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)
+             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+                 audio_path = tmp_file.name
+             sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
+         except Exception as e:
+             print(f"TTS error: {e}")
+
+     # Return messages (history plus the new user turn) so the user's message stays visible in the Chatbot.
+     return messages + [{"role": "assistant", "content": generated_text}], audio_path
+
+ # --- Transcribe Audio ---
+ @spaces.GPU
+ def transcribe_audio(filepath):
+     global first_load
+     if first_load:
+         load_models()
+         first_load = False
+
+     global whisper_processor, whisper_model
+     if whisper_model is None:
+         return "Whisper model not loaded."
+
+     try:
+         audio, sr = librosa.load(filepath, sr=16000)
+         inputs = whisper_processor(audio, sampling_rate=sr, return_tensors="pt").input_features.to(whisper_model.device)
+         outputs = whisper_model.generate(inputs)
+         return whisper_processor.batch_decode(outputs, skip_special_tokens=True)[0]
+     except Exception as e:
+         return f"Transcription failed: {e}"
+
+ # --- Gradio UI ---
+ with gr.Blocks() as demo:
+     gr.Markdown("# Qwen2.5 Chatbot with Voice Input/Output")
+
+     with gr.Tab("Chat"):
+         chatbot = gr.Chatbot(type='messages')
+         text_input = gr.Textbox(placeholder="Type your message...")
+         audio_output = gr.Audio(label="Response Audio", autoplay=True)
+         text_input.submit(generate_response_and_audio, [text_input, chatbot], [chatbot, audio_output])
+
+     with gr.Tab("Transcribe"):
+         audio_input = gr.Audio(type="filepath", label="Upload Audio")
+         transcribed = gr.Textbox(label="Transcription")
+         audio_input.upload(transcribe_audio, audio_input, transcribed)
+
+     clear_btn = gr.Button("Clear All")
+     clear_btn.click(lambda: ([], "", None), None, [chatbot, text_input, audio_output])
+
+ demo.queue().launch()
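
Note: the value passed between the gr.Chatbot(type="messages") component and generate_response_and_audio is an OpenAI-style list of role/content dicts, the same structure tokenizer.apply_chat_template consumes. A minimal sketch of that history format for reference (not part of app.py; the assistant text is a made-up placeholder):

# Illustration only: the "messages" history format used by the Chat tab.
example_history = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi! How can I help?"},  # placeholder reply
]
# On each submit, generate_response_and_audio appends the new user turn and the
# generated reply, and returns (updated_history, wav_path_or_None).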