migrate to quantized models
- app.py +4 -4
- requirements.txt +2 -0
- utils.py +36 -3
app.py
@@ -22,8 +22,8 @@ AVAILABLE_MODELS = {
 # Initialize the social graph manager
 social_graph = SocialGraphManager("social_graph.json")
 
-# Initialize the suggestion generator with Gemma 3
-suggestion_generator = SuggestionGenerator("google/gemma-3-
+# Initialize the suggestion generator with Gemma 3 1B (default - smaller model to save memory)
+suggestion_generator = SuggestionGenerator("google/gemma-3-1b-it")
 
 # Test the model to make sure it's working
 test_result = suggestion_generator.test_model()
@@ -153,7 +153,7 @@ def generate_suggestions(
     user_input,
     suggestion_type,
     selected_topic=None,
-    model_name="google/gemma-3-
+    model_name="google/gemma-3-1b-it",
     temperature=0.7,
     mood=3,
     progress=gr.Progress(),
@@ -462,7 +462,7 @@ with gr.Blocks(title="Will's AAC Communication Aid", css="custom.css") as demo:
     with gr.Row():
         model_dropdown = gr.Dropdown(
             choices=list(AVAILABLE_MODELS.keys()),
-            value="google/gemma-3-
+            value="google/gemma-3-1b-it",
             label="Language Model",
             info="Select which AI model to use for generating responses",
         )
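The default model name now appears in three places (the SuggestionGenerator call, the generate_suggestions default, and the dropdown value), and the dropdown value has to remain a key of AVAILABLE_MODELS. A minimal sanity check along these lines could catch a mismatch; the AVAILABLE_MODELS entry below is a hypothetical placeholder, since only the key "google/gemma-3-1b-it" is known from the diff.

# Hypothetical stand-in for app.py's AVAILABLE_MODELS; the real dict is assumed
# to map Hugging Face model IDs to human-readable descriptions, judging by
# choices=list(AVAILABLE_MODELS.keys()) above.
AVAILABLE_MODELS = {
    "google/gemma-3-1b-it": "Gemma 3 1B (instruction-tuned, default)",
}

DEFAULT_MODEL = "google/gemma-3-1b-it"

# The Gradio dropdown's default value should be one of its choices, so keep the
# three occurrences of the default in sync.
assert DEFAULT_MODEL in AVAILABLE_MODELS, f"{DEFAULT_MODEL} not in AVAILABLE_MODELS"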
requirements.txt
@@ -4,3 +4,5 @@ sentence-transformers>=2.2.2
 torch>=2.0.0
 numpy>=1.24.0
 openai-whisper>=20231117
+bitsandbytes>=0.41.0
+accelerate>=0.21.0
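The two new packages back the 4-bit loading path added in utils.py: bitsandbytes provides the quantized kernels and accelerate handles the device_map="auto" placement. A quick way to confirm they are actually installed at compatible versions (both packages expose __version__):

# Report the installed versions of the new quantization dependencies.
import bitsandbytes
import accelerate

print("bitsandbytes:", bitsandbytes.__version__)
print("accelerate:", accelerate.__version__)

Note that 4-bit loading with bitsandbytes generally assumes a CUDA-capable GPU; on a CPU-only host the quantized path in utils.py is expected to fail and fall through to the non-quantized fallback below.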
utils.py
@@ -216,6 +216,8 @@ class SuggestionGenerator:
         if is_gated_model:
             # Try to get token from environment
             import os
+            import torch
+            from transformers import BitsAndBytesConfig
 
             token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get(
                 "HF_TOKEN"
@@ -231,14 +233,31 @@
                 from transformers import AutoTokenizer, AutoModelForCausalLM
 
                 try:
+                    # Configure 4-bit quantization to save memory
+                    quantization_config = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_compute_dtype=torch.float16,
+                        bnb_4bit_quant_type="nf4",
+                        bnb_4bit_use_double_quant=True,
+                    )
+
                     tokenizer = AutoTokenizer.from_pretrained(
                         model_name, token=token
                     )
+
+                    # Load model with quantization
                     model = AutoModelForCausalLM.from_pretrained(
-                        model_name,
+                        model_name,
+                        token=token,
+                        quantization_config=quantization_config,
+                        device_map="auto",
                     )
+
                     self.generator = pipeline(
-                        "text-generation",
+                        "text-generation",
+                        model=model,
+                        tokenizer=tokenizer,
+                        torch_dtype=torch.float16,
                     )
                 except Exception as e:
                     print(f"Error loading gated model with token: {e}")
@@ -248,7 +267,21 @@
                     print(
                         "Please visit the model page on Hugging Face Hub and accept the license."
                     )
-
+                    # Try loading without quantization as fallback
+                    try:
+                        print("Trying to load model without quantization...")
+                        tokenizer = AutoTokenizer.from_pretrained(
+                            model_name, token=token
+                        )
+                        model = AutoModelForCausalLM.from_pretrained(
+                            model_name, token=token
+                        )
+                        self.generator = pipeline(
+                            "text-generation", model=model, tokenizer=tokenizer
+                        )
+                    except Exception as e2:
+                        print(f"Fallback loading also failed: {e2}")
+                        raise e
             else:
                 print("No Hugging Face token found in environment variables.")
                 print(
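For reference, the quantized loading path introduced above can be exercised on its own before wiring it into SuggestionGenerator. The sketch below reuses the exact BitsAndBytesConfig values from the diff; it assumes a CUDA-capable GPU, the bitsandbytes/accelerate versions from requirements.txt, and a HUGGING_FACE_HUB_TOKEN (or HF_TOKEN) for an account that has accepted the Gemma license.

# Standalone sketch of the 4-bit loading path (assumptions noted above).
import os

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

model_name = "google/gemma-3-1b-it"
token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")

# Same 4-bit configuration as in utils.py: NF4 weights, float16 compute,
# double quantization to shave a little more memory off the weights.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=token,
    quantization_config=quantization_config,
    device_map="auto",  # accelerate decides where the quantized weights live
)

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(generator("Hello, my name is", max_new_tokens=20)[0]["generated_text"])

In rough terms, NF4 stores each weight in about half a byte, so a 1B-parameter model's weights come to well under 1 GB before activations and the KV cache, which is the memory saving this commit is after.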