willwade committed
Commit 238c097 · 1 Parent(s): 35169ba

migrate to qt models

Files changed (3):
  1. app.py +4 -4
  2. requirements.txt +2 -0
  3. utils.py +36 -3
app.py CHANGED
@@ -22,8 +22,8 @@ AVAILABLE_MODELS = {
 # Initialize the social graph manager
 social_graph = SocialGraphManager("social_graph.json")
 
-# Initialize the suggestion generator with Gemma 3 4B (default)
-suggestion_generator = SuggestionGenerator("google/gemma-3-4b-it")
+# Initialize the suggestion generator with Gemma 3 1B (default - smaller model to save memory)
+suggestion_generator = SuggestionGenerator("google/gemma-3-1b-it")
 
 # Test the model to make sure it's working
 test_result = suggestion_generator.test_model()
@@ -153,7 +153,7 @@ def generate_suggestions(
     user_input,
     suggestion_type,
     selected_topic=None,
-    model_name="google/gemma-3-4b-it",
+    model_name="google/gemma-3-1b-it",
     temperature=0.7,
     mood=3,
     progress=gr.Progress(),
@@ -462,7 +462,7 @@ with gr.Blocks(title="Will's AAC Communication Aid", css="custom.css") as demo:
     with gr.Row():
         model_dropdown = gr.Dropdown(
             choices=list(AVAILABLE_MODELS.keys()),
-            value="google/gemma-3-4b-it",
+            value="google/gemma-3-1b-it",
             label="Language Model",
             info="Select which AI model to use for generating responses",
         )
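Note: the app.py changes swap the default model from google/gemma-3-4b-it to google/gemma-3-1b-it in all three places it appears: the startup initialisation, the generate_suggestions() default argument, and the Gradio dropdown default. A minimal sketch of exercising the new default, based only on the calls visible in this diff (the import path is an assumption; anything beyond the constructor and test_model() is not shown here):

# Sketch only: mirrors the startup sequence app.py uses with the new 1B default.
from utils import SuggestionGenerator  # assumes SuggestionGenerator is defined in utils.py

generator = SuggestionGenerator("google/gemma-3-1b-it")  # new, smaller default model
print(generator.test_model())  # smoke test, as app.py does right after initialisation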
requirements.txt CHANGED
@@ -4,3 +4,5 @@ sentence-transformers>=2.2.2
 torch>=2.0.0
 numpy>=1.24.0
 openai-whisper>=20231117
+bitsandbytes>=0.41.0
+accelerate>=0.21.0
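Note: the two new requirements back the quantized loading path added in utils.py below: bitsandbytes provides the 4-bit kernels behind BitsAndBytesConfig, and accelerate is what enables device_map="auto" placement. A rough environment check, as a sketch (the __version__ attributes are standard for these packages, but verify in your own setup):

# Quick sanity check that the new dependencies are importable and a GPU is visible;
# bitsandbytes 4-bit kernels generally require CUDA.
import torch
import bitsandbytes
import accelerate

print("bitsandbytes:", bitsandbytes.__version__)      # expect >= 0.41.0 per requirements.txt
print("accelerate:", accelerate.__version__)          # expect >= 0.21.0 per requirements.txt
print("CUDA available:", torch.cuda.is_available())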
utils.py CHANGED
@@ -216,6 +216,8 @@ class SuggestionGenerator:
         if is_gated_model:
             # Try to get token from environment
             import os
+            import torch
+            from transformers import BitsAndBytesConfig
 
             token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get(
                 "HF_TOKEN"
@@ -231,14 +233,31 @@ class SuggestionGenerator:
                 from transformers import AutoTokenizer, AutoModelForCausalLM
 
                 try:
+                    # Configure 4-bit quantization to save memory
+                    quantization_config = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_compute_dtype=torch.float16,
+                        bnb_4bit_quant_type="nf4",
+                        bnb_4bit_use_double_quant=True,
+                    )
+
                     tokenizer = AutoTokenizer.from_pretrained(
                         model_name, token=token
                     )
+
+                    # Load model with quantization
                     model = AutoModelForCausalLM.from_pretrained(
-                        model_name, token=token
+                        model_name,
+                        token=token,
+                        quantization_config=quantization_config,
+                        device_map="auto",
                     )
+
                     self.generator = pipeline(
-                        "text-generation", model=model, tokenizer=tokenizer
+                        "text-generation",
+                        model=model,
+                        tokenizer=tokenizer,
+                        torch_dtype=torch.float16,
                     )
                 except Exception as e:
                     print(f"Error loading gated model with token: {e}")
@@ -248,7 +267,21 @@ class SuggestionGenerator:
                     print(
                         "Please visit the model page on Hugging Face Hub and accept the license."
                     )
-                    raise
+                    # Try loading without quantization as fallback
+                    try:
+                        print("Trying to load model without quantization...")
+                        tokenizer = AutoTokenizer.from_pretrained(
+                            model_name, token=token
+                        )
+                        model = AutoModelForCausalLM.from_pretrained(
+                            model_name, token=token
+                        )
+                        self.generator = pipeline(
+                            "text-generation", model=model, tokenizer=tokenizer
+                        )
+                    except Exception as e2:
+                        print(f"Fallback loading also failed: {e2}")
+                        raise e
             else:
                 print("No Hugging Face token found in environment variables.")
                 print(
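Note: taken outside the class, the loading pattern this commit introduces looks roughly like the sketch below: configure NF4 4-bit quantization, load the model with device_map="auto", and wrap it in a text-generation pipeline. This is illustrative rather than a copy of the app's code; get_memory_footprint() is a stock transformers helper used only to confirm the saving, and the gated Gemma checkpoint still requires an accepted license and a Hugging Face token (or a prior huggingface-cli login).

# Standalone sketch of the 4-bit loading pattern added in utils.py (illustrative only).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

model_name = "google/gemma-3-1b-it"
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

# Pass token=... to both from_pretrained calls if you are not already logged in to the Hub.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    device_map="auto",
)
print(f"Model footprint: ~{model.get_memory_footprint() / 1e9:.2f} GB")  # should be well below the fp16 size

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(generator("Hello, how are you", max_new_tokens=20)[0]["generated_text"])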