ajsbsd committed
Commit 5199c19 · verified · Parent: 77557c0

Create app.py

Files changed (1)
app.py +210 -0
app.py ADDED
@@ -0,0 +1,210 @@
+ import gradio as gr
+ import torch
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     SpeechT5Processor,
+     SpeechT5ForTextToSpeech,
+     SpeechT5HifiGan,
+     WhisperProcessor,
+     WhisperForConditionalGeneration
+ )
+ from datasets import load_dataset
+ import os
+ import spaces
+ import tempfile
+ import soundfile as sf
+ import librosa
+
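+ # Beyond transformers: datasets supplies the speaker x-vectors, soundfile
+ # writes the output WAV, librosa resamples uploaded audio to 16 kHz, and
+ # spaces provides the ZeroGPU decorator.
+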
+ # --- Configuration ---
+ HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
+ TORCH_DTYPE = torch.bfloat16
+ MAX_NEW_TOKENS = 512
+ DO_SAMPLE = True
+ TEMPERATURE = 0.7
+ TOP_K = 50
+ TOP_P = 0.95
+
+ TTS_MODEL_ID = "microsoft/speecht5_tts"
+ TTS_VOCODER_ID = "microsoft/speecht5_hifigan"
+ STT_MODEL_ID = "openai/whisper-small"
+
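+ # bfloat16 roughly halves the LLM's memory footprint versus float32; all
+ # three model IDs above are plain Hub references and can be swapped out.
+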
+ # --- Global Variables ---
+ tokenizer = None
+ llm_model = None
+ tts_processor = None
+ tts_model = None
+ tts_vocoder = None
+ speaker_embeddings = None
+ whisper_processor = None
+ whisper_model = None
+ first_load = True
+
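+ # SpeechT5 truncates long inputs, so replies are synthesized in
+ # sentence-sized chunks and the waveforms are concatenated afterwards.
+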
+ # --- Helper: Split Text Into Chunks ---
+ def split_text_into_chunks(text, max_chars=400):
+     sentences = text.replace("...", ".").split(". ")
+     chunks = []
+     current_chunk = ""
+     for sentence in sentences:
+         if len(current_chunk) + len(sentence) + 2 < max_chars:
+             current_chunk += (". " + sentence) if current_chunk else sentence
+         else:
+             chunks.append(current_chunk)
+             current_chunk = sentence
+     if current_chunk:
+         chunks.append(current_chunk)
+     return [f"{chunk}." for chunk in chunks if chunk.strip()]
+
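+ # Loading happens lazily inside a @spaces.GPU function so a ZeroGPU Space
+ # only requests a GPU once the first request actually arrives.
+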
+ # --- Load Models Function ---
+ @spaces.GPU
+ def load_models():
+     global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings, whisper_processor, whisper_model
+
+     hf_token = os.environ.get("HF_TOKEN")
+
+     # LLM
+     if tokenizer is None or llm_model is None:
+         try:
+             tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
+             if tokenizer.pad_token is None:
+                 tokenizer.pad_token = tokenizer.eos_token
+             llm_model = AutoModelForCausalLM.from_pretrained(
+                 HUGGINGFACE_MODEL_ID,
+                 torch_dtype=TORCH_DTYPE,
+                 device_map="auto",
+                 token=hf_token
+             ).eval()
+             print("LLM loaded successfully.")
+         except Exception as e:
+             print(f"Error loading LLM: {e}")
+
+     # TTS
+     if tts_processor is None or tts_model is None or tts_vocoder is None:
+         try:
+             tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
+             tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
+             tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)
+             embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
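+             # Index 7306 is the x-vector used in the SpeechT5 examples
+             # (a US English female voice from CMU ARCTIC).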
+             speaker_embeddings = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)
+             device = llm_model.device if llm_model else 'cpu'
+             tts_model.to(device)
+             tts_vocoder.to(device)
+             speaker_embeddings = speaker_embeddings.to(device)
+             print("TTS models loaded.")
+         except Exception as e:
+             print(f"Error loading TTS: {e}")
+
+     # STT
+     if whisper_processor is None or whisper_model is None:
+         try:
+             whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
+             whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)
+             device = llm_model.device if llm_model else 'cpu'
+             whisper_model.to(device)
+             print("Whisper loaded.")
+         except Exception as e:
+             print(f"Error loading Whisper: {e}")
+
+ # --- Generate Response and Audio ---
+ @spaces.GPU
+ def generate_response_and_audio(message, history):
+     global first_load
+     if first_load:
+         load_models()
+         first_load = False
+
+     global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
+
+     if tokenizer is None or llm_model is None:
+         return history + [{"role": "assistant", "content": "Error: LLM not loaded."}], None
+
+     messages = history.copy()
+     messages.append({"role": "user", "content": message})
+
+     try:
+         input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     except Exception:
+         # Fallback for tokenizers without a chat template.
+         input_text = ""
+         for item in history:
+             input_text += f"{item['role'].capitalize()}: {item['content']}\n"
+         input_text += f"User: {message}\nAssistant:"
+
+     try:
+         inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(llm_model.device)
+         output_ids = llm_model.generate(
+             inputs["input_ids"],
+             attention_mask=inputs["attention_mask"],
+             max_new_tokens=MAX_NEW_TOKENS,
+             do_sample=DO_SAMPLE,
+             temperature=TEMPERATURE,
+             top_k=TOP_K,
+             top_p=TOP_P,
+             pad_token_id=tokenizer.eos_token_id
+         )
+         # Decode only the newly generated tokens, not the prompt.
+         generated_text = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True).strip()
+     except Exception as e:
+         print(f"LLM error: {e}")
+         return messages + [{"role": "assistant", "content": "I had an issue generating a response."}], None
+
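+     # SpeechT5 emits 16 kHz mono audio; chunk waveforms are concatenated
+     # and written to a temporary WAV file that Gradio can serve.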
+     audio_path = None
+     if all(x is not None for x in (tts_processor, tts_model, tts_vocoder, speaker_embeddings)):
+         try:
+             device = llm_model.device
+             text_chunks = split_text_into_chunks(generated_text)
+
+             full_speech = []
+             for chunk in text_chunks:
+                 tts_inputs = tts_processor(text=chunk, return_tensors="pt", max_length=512, truncation=True).to(device)
+                 speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)
+                 full_speech.append(speech.cpu())
+
+             full_speech_tensor = torch.cat(full_speech, dim=0)
+
+             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+                 audio_path = tmp_file.name
+                 sf.write(audio_path, full_speech_tensor.numpy(), samplerate=16000)
+
+         except Exception as e:
+             print(f"TTS error: {e}")
+
+     # Return messages (history plus the new user turn) so the user's
+     # message stays visible in the chat.
+     return messages + [{"role": "assistant", "content": generated_text}], audio_path
+
+ # --- Transcribe Audio ---
+ @spaces.GPU
+ def transcribe_audio(filepath):
+     global first_load
+     if first_load:
+         load_models()
+         first_load = False
+
+     global whisper_processor, whisper_model
+     if whisper_model is None:
+         return "Whisper model not loaded."
+
+     try:
+         # Whisper expects 16 kHz mono input; librosa resamples on load.
+         audio, sr = librosa.load(filepath, sr=16000)
+         inputs = whisper_processor(audio, sampling_rate=sr, return_tensors="pt").input_features.to(whisper_model.device)
+         outputs = whisper_model.generate(inputs)
+         return whisper_processor.batch_decode(outputs, skip_special_tokens=True)[0]
+     except Exception as e:
+         return f"Transcription failed: {e}"
+
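+ # Two tabs: a chat tab that returns text plus autoplayed speech, and a
+ # transcription tab that runs Whisper on an uploaded file.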
+ # --- Gradio UI ---
+ with gr.Blocks() as demo:
+     gr.Markdown("# Qwen2.5 Chatbot with Voice Input/Output")
+
+     with gr.Tab("Chat"):
+         chatbot = gr.Chatbot(type='messages')
+         text_input = gr.Textbox(placeholder="Type your message...")
+         audio_output = gr.Audio(label="Response Audio", autoplay=True)
+         text_input.submit(generate_response_and_audio, [text_input, chatbot], [chatbot, audio_output])
+
+     with gr.Tab("Transcribe"):
+         audio_input = gr.Audio(type="filepath", label="Upload Audio")
+         transcribed = gr.Textbox(label="Transcription")
+         audio_input.upload(transcribe_audio, audio_input, transcribed)
+
+     clear_btn = gr.Button("Clear All")
+     clear_btn.click(lambda: ([], "", None), None, [chatbot, text_input, audio_output])
+
+ demo.queue().launch()