ajsbsd commited on
Commit
77557c0
·
verified ·
1 Parent(s): 589237f

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -187
app.py DELETED
@@ -1,187 +0,0 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import (
4
- AutoTokenizer,
5
- AutoModelForCausalLM,
6
- SpeechT5Processor,
7
- SpeechT5ForTextToSpeech,
8
- SpeechT5HifiGan,
9
- WhisperProcessor,
10
- WhisperForConditionalGeneration
11
- )
12
- from datasets import load_dataset
13
- import os
14
- import spaces
15
- import tempfile
16
- import soundfile as sf
17
- import librosa
18
-
19
- # --- Configuration ---
20
- HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
21
- TORCH_DTYPE = torch.bfloat16
22
- MAX_NEW_TOKENS = 512
23
- DO_SAMPLE = True
24
- TEMPERATURE = 0.7
25
- TOP_K = 50
26
- TOP_P = 0.95
27
-
28
- TTS_MODEL_ID = "microsoft/speecht5_tts"
29
- TTS_VOCODER_ID = "microsoft/speecht5_hifigan"
30
- STT_MODEL_ID = "openai/whisper-small"
31
-
32
- # --- Global Variables ---
33
- tokenizer = None
34
- llm_model = None
35
- tts_processor = None
36
- tts_model = None
37
- tts_vocoder = None
38
- speaker_embeddings = None
39
- whisper_processor = None
40
- whisper_model = None
41
- first_load = True
42
-
43
- # --- Load Models Function ---
44
- @spaces.GPU
45
- def load_models():
46
- global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
47
- global whisper_processor, whisper_model
48
-
49
- if (tokenizer is not None and llm_model is not None and tts_model is not None and
50
- whisper_model is not None):
51
- print("All models already loaded.")
52
- return
53
-
54
- hf_token = os.environ.get("HF_TOKEN")
55
-
56
- # LLM
57
- try:
58
- tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
59
- if tokenizer.pad_token is None:
60
- tokenizer.pad_token = tokenizer.eos_token
61
- llm_model = AutoModelForCausalLM.from_pretrained(
62
- HUGGINGFACE_MODEL_ID,
63
- torch_dtype=TORCH_DTYPE,
64
- device_map="auto",
65
- token=hf_token
66
- ).eval()
67
- print("LLM loaded successfully.")
68
- except Exception as e:
69
- print(f"Error loading LLM: {e}")
70
-
71
- # TTS
72
- try:
73
- tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
74
- tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
75
- tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)
76
- embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
77
- speaker_embeddings = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)
78
- device = llm_model.device if llm_model else 'cpu'
79
- tts_model.to(device)
80
- tts_vocoder.to(device)
81
- speaker_embeddings = speaker_embeddings.to(device)
82
- print("TTS models loaded.")
83
- except Exception as e:
84
- print(f"Error loading TTS: {e}")
85
-
86
- # STT
87
- try:
88
- whisper_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, token=hf_token)
89
- whisper_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID, token=hf_token)
90
- whisper_model.to(llm_model.device if llm_model else 'cpu')
91
- print("Whisper loaded.")
92
- except Exception as e:
93
- print(f"Error loading Whisper: {e}")
94
-
95
- # --- Generate Response + Audio ---
96
- @spaces.GPU
97
- def generate_response_and_audio(message, history):
98
- global first_load
99
- if first_load:
100
- load_models()
101
- first_load = False
102
-
103
- global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
104
-
105
- if tokenizer is None or llm_model is None:
106
- return [{"role": "assistant", "content": "Error: LLM not loaded."}], None
107
-
108
- messages = history.copy()
109
- messages.append({"role": "user", "content": message})
110
-
111
- try:
112
- input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
113
- except:
114
- input_text = ""
115
- for item in history:
116
- input_text += f"{item['role'].capitalize()}: {item['content']}\n"
117
- input_text += f"User: {message}\nAssistant:"
118
-
119
- try:
120
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(llm_model.device)
121
- output_ids = llm_model.generate(
122
- inputs["input_ids"],
123
- attention_mask=inputs["attention_mask"],
124
- max_new_tokens=MAX_NEW_TOKENS,
125
- do_sample=DO_SAMPLE,
126
- temperature=TEMPERATURE,
127
- top_k=TOP_K,
128
- top_p=TOP_P,
129
- pad_token_id=tokenizer.eos_token_id
130
- )
131
- generated_text = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True).strip()
132
- except Exception as e:
133
- print(f"LLM error: {e}")
134
- return history + [{"role": "assistant", "content": "I had an issue generating a response."}], None
135
-
136
- audio_path = None
137
- if None not in [tts_processor, tts_model, tts_vocoder, speaker_embeddings]:
138
- try:
139
- tts_inputs = tts_processor(text=generated_text, return_tensors="pt", max_length=550, truncation=True).to(llm_model.device)
140
- speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)
141
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
142
- audio_path = tmp_file.name
143
- sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
144
- except Exception as e:
145
- print(f"TTS error: {e}")
146
-
147
- return history + [{"role": "assistant", "content": generated_text}], audio_path
148
-
149
- # --- Transcribe Audio ---
150
- @spaces.GPU
151
- def transcribe_audio(filepath):
152
- global first_load
153
- if first_load:
154
- load_models()
155
- first_load = False
156
-
157
- global whisper_processor, whisper_model
158
- if whisper_model is None:
159
- return "Whisper model not loaded."
160
-
161
- try:
162
- audio, sr = librosa.load(filepath, sr=16000)
163
- inputs = whisper_processor(audio, sampling_rate=sr, return_tensors="pt").input_features.to(whisper_model.device)
164
- outputs = whisper_model.generate(inputs)
165
- return whisper_processor.batch_decode(outputs, skip_special_tokens=True)[0]
166
- except Exception as e:
167
- return f"Transcription failed: {e}"
168
-
169
- # --- Gradio UI ---
170
- with gr.Blocks() as demo:
171
- gr.Markdown("# Qwen2.5 Chatbot with Voice Input/Output")
172
-
173
- with gr.Tab("Chat"):
174
- chatbot = gr.Chatbot(type='messages')
175
- text_input = gr.Textbox(placeholder="Type your message...")
176
- audio_output = gr.Audio(label="Response Audio", autoplay=True)
177
- text_input.submit(generate_response_and_audio, [text_input, chatbot], [chatbot, audio_output])
178
-
179
- with gr.Tab("Transcribe"):
180
- audio_input = gr.Audio(type="filepath", label="Upload Audio")
181
- transcribed = gr.Textbox(label="Transcription")
182
- audio_input.upload(transcribe_audio, audio_input, transcribed)
183
-
184
- clear_btn = gr.Button("Clear All")
185
- clear_btn.click(lambda: ([], "", None), None, [chatbot, text_input, audio_output])
186
-
187
- demo.queue().launch()