ajsbsd committed
Commit 28ccf5e · verified · 1 Parent(s): 07b797f

Update app.py

Files changed (1): app.py (+147 -65)
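For orientation before the diff: app.py keeps the chat history as a list of role/content dictionaries (the type='messages' Chatbot format referenced in the code comments), and the new handler returns that history together with the path of a generated WAV file. A rough, illustrative sketch of that contract (the sample strings below are invented, not taken from the commit):

# Illustrative only: the message-style history that the handlers in app.py expect,
# and the two-value return of the new generate_response_and_audio handler.
history = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]

# In the updated app.py, roughly:
#   updated_history, audio_path = generate_response_and_audio("Tell me a joke", history)
# where audio_path points to a temporary .wav file, or is None if TTS failed.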
app.py CHANGED
@@ -1,91 +1,137 @@
 import gradio as gr
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
+from datasets import load_dataset # To get a speaker embedding for TTS
 import os
-import spaces
+import spaces # Import the spaces library for GPU decorator
+import tempfile # For creating temporary audio files
+import soundfile as sf # To save audio files
 
-# --- Configuration ---
-# IMPORTANT: Replace with the path to your locally downloaded model or a Hugging Face model ID.
-# Examples:
-# LOCAL_MODEL_PATH = "/path/to/your/downloaded/qwen-1.5b-instruct"
-# HUGGINGFACE_MODEL_ID = "Qwen/Qwen1.5-1.8B-Chat" # For a smaller Qwen model for local testing
+# --- Configuration for Language Model (LLM) ---
+# IMPORTANT: When deploying to Hugging Face Spaces, it's best to use the Hugging Face model ID
+# rather than a local path ('.'), as the Space will fetch it from the Hub.
 HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"
 
 # You might need to adjust TORCH_DTYPE based on your GPU and model support
 # torch.float16 (FP16) is common for inference, torch.bfloat16 for newer GPUs
-TORCH_DTYPE = torch.float16 # or torch.bfloat16 or torch.float32
+# For ZeroGPU/H200, bfloat16 is often preferred if the model supports it and GPU allows.
+TORCH_DTYPE = torch.bfloat16 # Use bfloat16 for optimal H200 performance
 
-# Generation parameters (can be adjusted for different response styles)
+# Generation parameters for the LLM (can be adjusted for different response styles)
 MAX_NEW_TOKENS = 512
 DO_SAMPLE = True
 TEMPERATURE = 0.7
 TOP_K = 50
 TOP_P = 0.95
 
+# --- Configuration for Text-to-Speech (TTS) ---
+TTS_MODEL_ID = "microsoft/speecht5_tts"
+TTS_VOCODER_ID = "microsoft/speecht5_hifigan"
+
 # --- Global variables for models and tokenizers ---
 tokenizer = None
-model = None
+llm_model = None # Renamed to avoid conflict with tts_model
+tts_processor = None
+tts_model = None
+tts_vocoder = None
+speaker_embeddings = None # Global for TTS speaker embedding
 
 # --- Load Models and Tokenizers Function ---
-@spaces.GPU
-def load_model_and_tokenizer():
+@spaces.GPU # Decorate with @spaces.GPU to signal this function needs GPU access
+def load_models():
     """
-    Loads the language model and tokenizer from Hugging Face Hub or a local path.
-    This function will be called once when the Gradio app starts up.
+    Loads the language model, tokenizer, TTS models, and speaker embeddings
+    from Hugging Face Hub. This function will be called once when the Gradio app starts up.
     """
-    global tokenizer, model
+    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
 
-    if tokenizer is not None and model is not None:
-        print("Model and tokenizer already loaded.")
+    if tokenizer is not None and llm_model is not None and tts_model is not None:
+        print("All models and tokenizers already loaded.")
         return
 
-    print(f"Loading tokenizer from: {HUGGINGFACE_MODEL_ID}")
+    # When deploying to HF Spaces, you generally don't need an explicit HF_TOKEN
+    # for public models, but it's good practice for private models or if
+    # rate limits are hit.
+    hf_token = os.environ.get("HF_TOKEN") # Access HF_TOKEN from Space secrets if set
+
+    # Load Language Model (LLM)
+    print(f"Loading LLM tokenizer from: {HUGGINGFACE_MODEL_ID}")
     try:
-        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID)
+        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID, token=hf_token)
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
             print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")
 
-        print(f"Loading model from: {HUGGINGFACE_MODEL_ID}...")
-        model = AutoModelForCausalLM.from_pretrained(
+        print(f"Loading LLM model from: {HUGGINGFACE_MODEL_ID}...")
+        llm_model = AutoModelForCausalLM.from_pretrained(
             HUGGINGFACE_MODEL_ID,
             torch_dtype=TORCH_DTYPE,
-            device_map="auto" # Automatically maps model to GPU if available, else CPU
+            device_map="auto", # Automatically maps model to GPU if available, else CPU
+            token=hf_token # Pass token if loading private model
         )
-        model.eval() # Set model to evaluation mode
-        print("Model loaded successfully.")
+        llm_model.eval() # Set model to evaluation mode
+        print("LLM model loaded successfully.")
     except Exception as e:
-        print(f"Error loading model or tokenizer: {e}")
-        print("Please ensure the model ID is correct and you have an internet connection for initial download, or the local path is valid.")
+        print(f"Error loading LLM model or tokenizer: {e}")
+        print("Please ensure the LLM model ID is correct and you have an internet connection for initial download, or the local path is valid.")
         tokenizer = None
-        model = None
-        raise RuntimeError("Failed to load model. Check your model ID/path and internet connection.")
+        llm_model = None
+        raise RuntimeError("Failed to load LLM model. Check your model ID/path and internet connection.")
+
+    # Load TTS models
+    print(f"Loading TTS processor, model, and vocoder from: {TTS_MODEL_ID}, {TTS_VOCODER_ID}")
+    try:
+        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL_ID, token=hf_token)
+        tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL_ID, token=hf_token)
+        tts_vocoder = SpeechT5HifiGan.from_pretrained(TTS_VOCODER_ID, token=hf_token)
+
+        # Load a speaker embedding (essential for SpeechT5 TTS)
+        # Using a sample from a public dataset for demonstration
+        print("Loading speaker embeddings for TTS...")
+        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", token=hf_token)
+        # Using a specific speaker embedding (you can experiment with different indices)
+        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
 
+        # Move TTS components to the same device as the LLM model
+        device = llm_model.device if llm_model else 'cpu'
+        tts_model.to(device)
+        tts_vocoder.to(device)
+        speaker_embeddings = speaker_embeddings.to(device)
+        print(f"TTS models and speaker embeddings loaded successfully to device: {device}.")
 
-# --- Generate Response Function ---
-@spaces.GPU
-def generate_response(
+    except Exception as e:
+        print(f"Error loading TTS models or speaker embeddings: {e}")
+        print("Please ensure TTS model IDs are correct and you have an internet connection.")
+        tts_processor = None
+        tts_model = None
+        tts_vocoder = None
+        speaker_embeddings = None
+        raise RuntimeError("Failed to load TTS components. Check model IDs and internet connection.")
+
+
+# --- Generate Response and Audio Function ---
+@spaces.GPU # Decorate with @spaces.GPU as this function performs GPU-intensive inference
+def generate_response_and_audio(
     message: str, # Current user message
     history: list # Gradio Chatbot history format (list of dictionaries with 'role' and 'content')
-) -> list: # Returns updated history for the Chatbot
+) -> tuple: # Returns (updated_history, audio_file_path)
     """
-    Generates a text response from the loaded model based on user input and chat history.
+    Generates a text response from the loaded LLM and then converts it to audio
+    using the loaded TTS model.
    """
-    global tokenizer, model
+    global tokenizer, llm_model, tts_processor, tts_model, tts_vocoder, speaker_embeddings
 
-    # Initialize models if not already loaded
-    if tokenizer is None or model is None:
-        load_model_and_tokenizer()
+    # Initialize all models if not already loaded
+    if tokenizer is None or llm_model is None or tts_model is None:
+        load_models()
 
-    if tokenizer is None or model is None: # Check again in case loading failed
-        # history.append([message, "Error: Chatbot model not loaded. Please check logs."])
-        # For 'messages' type history, append a dictionary
+    if tokenizer is None or llm_model is None: # Check LLM loading status
         history.append({"role": "user", "content": message})
-        history.append({"role": "assistant", "content": "Error: Chatbot model not loaded. Please check logs."})
-        return history
+        history.append({"role": "assistant", "content": "Error: Chatbot LLM not loaded. Please check logs."})
+        return history, None
 
-    # Format messages for the model's chat template (e.g., for Instruct models)
-    # The 'history' now directly contains dictionaries if type='messages' is used.
+    # --- 1. Generate Text Response (LLM) ---
+    # Format messages for the model's chat template
     messages = history # Use history directly as it's already in the correct format
     messages.append({"role": "user", "content": message}) # Add current user message
 
@@ -94,8 +140,7 @@ def generate_response(
         input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     except Exception as e:
         print(f"Error applying chat template: {e}")
-        # Fallback if chat template fails (e.g., for non-chat models)
-        # Reconstruct input_text for models without explicit chat templates
+        # Fallback for models without explicit chat templates
        input_text = ""
        for item in history:
            if item["role"] == "user":
@@ -104,11 +149,11 @@
                input_text += f"Assistant: {item['content']}\n"
        input_text += f"User: {message}\nAssistant:"
 
-    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
+    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(llm_model.device)
 
     # Generate response
     with torch.no_grad(): # Disable gradient calculations for inference
-        output_ids = model.generate(
+        output_ids = llm_model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
@@ -122,11 +167,45 @@
     generated_token_ids = output_ids[0][input_ids.shape[-1]:]
     generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()
 
-    # --- Update Chat History ---
+    # --- 2. Generate Audio from Response (TTS) ---
+    audio_path = None
+    if tts_processor and tts_model and tts_vocoder and speaker_embeddings is not None:
+        try:
+            # Ensure TTS components are on the correct device
+            device = llm_model.device if llm_model else 'cpu'
+            tts_model.to(device)
+            tts_vocoder.to(device)
+            speaker_embeddings = speaker_embeddings.to(device)
+
+            tts_inputs = tts_processor(
+                text=generated_text,
+                return_tensors="pt",
+                max_length=550, # Set a max length to prevent excessively long audio
+                truncation=True # Enable truncation if text exceeds max_length
+            ).to(device)
+
+            with torch.no_grad():
+                speech = tts_model.generate_speech(tts_inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder)
+
+            # Create a temporary file to save the audio
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+                audio_path = tmp_file.name
+                # Ensure audio data is on CPU before saving with soundfile
+                sf.write(audio_path, speech.cpu().numpy(), samplerate=16000)
+            print(f"Audio saved to: {audio_path}")
+
+        except Exception as e:
+            print(f"Error generating audio: {e}")
+            audio_path = None # Return None if audio generation fails
+    else:
+        print("TTS components not loaded. Skipping audio generation.")
+
+
+    # --- 3. Update Chat History ---
     # Append the latest generated response to the history with its role
     history.append({"role": "assistant", "content": generated_text})
 
-    return history
+    return history, audio_path
 
 # --- Gradio Interface ---
 with gr.Blocks() as demo:
@@ -147,34 +226,37 @@ with gr.Blocks() as demo:
        )
        submit_button = gr.Button("Send", scale=1)
 
+    audio_output = gr.Audio(
+        label="Listen to Response",
+        autoplay=True, # Automatically play audio
+        interactive=False # Don't allow user to interact with this audio component
+    )
+
    # Link the text input and button to the generation function
-    # Note: 'inputs' will be current message and the full history (as 'messages' type)
-    # 'outputs' will be the updated full history
+    # Outputs now include both the chatbot history and the audio file path
    submit_button.click(
-        fn=generate_response,
-        inputs=[text_input, chatbot], # text_input is the new message, chatbot is the history
-        outputs=[chatbot],
+        fn=generate_response_and_audio,
+        inputs=[text_input, chatbot],
+        outputs=[chatbot, audio_output],
        queue=True # Queue requests for better concurrency
    )
    text_input.submit( # Also trigger on Enter key
-        fn=generate_response,
+        fn=generate_response_and_audio,
        inputs=[text_input, chatbot],
-        outputs=[chatbot],
+        outputs=[chatbot, audio_output],
        queue=True
    )
 
    # Clear button
    def clear_chat():
-        # When type='messages', the clear function should return an empty list for history
-        # and an empty string for the text input.
-        return [], ""
+        # Clear history, text input, and audio output
+        return [], "", None
    clear_button = gr.Button("Clear Chat")
-    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input])
+    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input, audio_output])
 
 
-# Load the model when the app starts. This will ensure it's ready when the first request comes in.
-load_model_and_tokenizer()
+# Load all models when the app starts up
+load_models()
 
 # Launch the Gradio app
-#demo.queue().launch() # For local development, use launch()
-demo.queue().launch(server_name="0.0.0.0")
+demo.queue().launch()
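For reference, the TTS path this commit adds can be exercised on its own, outside Gradio. A minimal sketch, assuming the same microsoft/speecht5_tts and microsoft/speecht5_hifigan checkpoints and the Matthijs/cmu-arctic-xvectors speaker embedding that app.py uses; the prompt text and output filename are arbitrary:

# Standalone smoke test for the SpeechT5 pipeline used in load_models() and
# generate_response_and_audio(); not part of the commit itself.
import torch
import soundfile as sf
from datasets import load_dataset
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# One x-vector from the CMU ARCTIC set, the same index app.py picks
embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)

inputs = processor(text="Hello from the Space.", return_tensors="pt")
with torch.no_grad():
    speech = model.generate_speech(inputs["input_ids"], speaker_embedding, vocoder=vocoder)

# SpeechT5 generates 16 kHz audio, which is why app.py writes its temporary WAV
# with samplerate=16000.
sf.write("speecht5_smoke_test.wav", speech.numpy(), samplerate=16000)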