Add progress indicators and model caching
Browse files
app.py
CHANGED
@@ -103,11 +103,12 @@ def on_person_change(person_id):
|
|
103 |
return context_info, phrases_text, topics
|
104 |
|
105 |
|
106 |
-
def change_model(model_name):
|
107 |
"""Change the language model used for generation.
|
108 |
|
109 |
Args:
|
110 |
model_name: The name of the model to use
|
|
|
111 |
|
112 |
Returns:
|
113 |
A status message about the model change
|
@@ -120,12 +121,17 @@ def change_model(model_name):
|
|
120 |
if model_name == suggestion_generator.model_name:
|
121 |
return f"Already using model: {model_name}"
|
122 |
|
|
|
|
|
|
|
123 |
# Try to load the new model
|
124 |
success = suggestion_generator.load_model(model_name)
|
125 |
|
126 |
if success:
|
|
|
127 |
return f"Successfully switched to model: {model_name}"
|
128 |
else:
|
|
|
129 |
return f"Failed to load model: {model_name}. Using fallback responses instead."
|
130 |
|
131 |
|
@@ -136,6 +142,7 @@ def generate_suggestions(
|
|
136 |
selected_topic=None,
|
137 |
model_name="distilgpt2",
|
138 |
temperature=0.7,
|
|
|
139 |
):
|
140 |
"""Generate suggestions based on the selected person and user input."""
|
141 |
print(
|
@@ -144,13 +151,17 @@ def generate_suggestions(
|
|
144 |
f"model={model_name}, temperature={temperature}"
|
145 |
)
|
146 |
|
|
|
|
|
|
|
147 |
if not person_id:
|
148 |
print("No person_id provided")
|
149 |
return "Please select who you're talking to first."
|
150 |
|
151 |
# Make sure we're using the right model
|
152 |
if model_name != suggestion_generator.model_name:
|
153 |
-
|
|
|
154 |
|
155 |
person_context = social_graph.get_person_context(person_id)
|
156 |
print(f"Person context: {person_context}")
|
@@ -206,9 +217,13 @@ def generate_suggestions(
|
|
206 |
# If suggestion type is "model", use the language model for multiple suggestions
|
207 |
if suggestion_type == "model":
|
208 |
print("Using model for suggestions")
|
|
|
|
|
209 |
# Generate 3 different suggestions
|
210 |
suggestions = []
|
211 |
for i in range(3):
|
|
|
|
|
212 |
print(f"Generating suggestion {i+1}/3")
|
213 |
try:
|
214 |
suggestion = suggestion_generator.generate_suggestion(
|
@@ -247,9 +262,14 @@ def generate_suggestions(
|
|
247 |
else:
|
248 |
print("No category inferred, falling back to model")
|
249 |
# Fall back to model if we couldn't infer a category
|
|
|
250 |
try:
|
251 |
suggestions = []
|
252 |
for i in range(3):
|
|
|
|
|
|
|
|
|
253 |
suggestion = suggestion_generator.generate_suggestion(
|
254 |
person_context, user_input, temperature=temperature
|
255 |
)
|
@@ -278,6 +298,10 @@ def generate_suggestions(
|
|
278 |
result = "No suggestions available. Please try a different option."
|
279 |
|
280 |
print(f"Returning result: {result[:100]}...")
|
|
|
|
|
|
|
|
|
281 |
return result
|
282 |
|
283 |
|
|
|
103 |
return context_info, phrases_text, topics
|
104 |
|
105 |
|
106 |
+
def change_model(model_name, progress=gr.Progress()):
|
107 |
"""Change the language model used for generation.
|
108 |
|
109 |
Args:
|
110 |
model_name: The name of the model to use
|
111 |
+
progress: Gradio progress indicator
|
112 |
|
113 |
Returns:
|
114 |
A status message about the model change
|
|
|
121 |
if model_name == suggestion_generator.model_name:
|
122 |
return f"Already using model: {model_name}"
|
123 |
|
124 |
+
# Show progress indicator
|
125 |
+
progress(0, desc=f"Loading model: {model_name}")
|
126 |
+
|
127 |
# Try to load the new model
|
128 |
success = suggestion_generator.load_model(model_name)
|
129 |
|
130 |
if success:
|
131 |
+
progress(1.0, desc=f"Model loaded: {model_name}")
|
132 |
return f"Successfully switched to model: {model_name}"
|
133 |
else:
|
134 |
+
progress(1.0, desc="Model loading failed")
|
135 |
return f"Failed to load model: {model_name}. Using fallback responses instead."
|
136 |
|
137 |
|
|
|
142 |
selected_topic=None,
|
143 |
model_name="distilgpt2",
|
144 |
temperature=0.7,
|
145 |
+
progress=gr.Progress(),
|
146 |
):
|
147 |
"""Generate suggestions based on the selected person and user input."""
|
148 |
print(
|
|
|
151 |
f"model={model_name}, temperature={temperature}"
|
152 |
)
|
153 |
|
154 |
+
# Initialize progress
|
155 |
+
progress(0, desc="Starting...")
|
156 |
+
|
157 |
if not person_id:
|
158 |
print("No person_id provided")
|
159 |
return "Please select who you're talking to first."
|
160 |
|
161 |
# Make sure we're using the right model
|
162 |
if model_name != suggestion_generator.model_name:
|
163 |
+
progress(0.1, desc=f"Switching to model: {model_name}")
|
164 |
+
change_model(model_name, progress)
|
165 |
|
166 |
person_context = social_graph.get_person_context(person_id)
|
167 |
print(f"Person context: {person_context}")
|
|
|
217 |
# If suggestion type is "model", use the language model for multiple suggestions
|
218 |
if suggestion_type == "model":
|
219 |
print("Using model for suggestions")
|
220 |
+
progress(0.2, desc="Preparing to generate suggestions...")
|
221 |
+
|
222 |
# Generate 3 different suggestions
|
223 |
suggestions = []
|
224 |
for i in range(3):
|
225 |
+
progress_value = 0.3 + (i * 0.2) # Progress from 30% to 70%
|
226 |
+
progress(progress_value, desc=f"Generating suggestion {i+1}/3")
|
227 |
print(f"Generating suggestion {i+1}/3")
|
228 |
try:
|
229 |
suggestion = suggestion_generator.generate_suggestion(
|
|
|
262 |
else:
|
263 |
print("No category inferred, falling back to model")
|
264 |
# Fall back to model if we couldn't infer a category
|
265 |
+
progress(0.3, desc="No category detected, using model instead...")
|
266 |
try:
|
267 |
suggestions = []
|
268 |
for i in range(3):
|
269 |
+
progress_value = 0.4 + (i * 0.15) # Progress from 40% to 70%
|
270 |
+
progress(
|
271 |
+
progress_value, desc=f"Generating fallback suggestion {i+1}/3"
|
272 |
+
)
|
273 |
suggestion = suggestion_generator.generate_suggestion(
|
274 |
person_context, user_input, temperature=temperature
|
275 |
)
|
|
|
298 |
result = "No suggestions available. Please try a different option."
|
299 |
|
300 |
print(f"Returning result: {result[:100]}...")
|
301 |
+
|
302 |
+
# Complete the progress
|
303 |
+
progress(1.0, desc="Completed!")
|
304 |
+
|
305 |
return result
|
306 |
|
307 |
|
utils.py
CHANGED
@@ -161,6 +161,7 @@ class SuggestionGenerator:
|
|
161 |
self.model_loaded = False
|
162 |
self.generator = None
|
163 |
self.aac_user_info = None
|
|
|
164 |
|
165 |
# Load AAC user information from social graph
|
166 |
try:
|
@@ -196,6 +197,13 @@ class SuggestionGenerator:
|
|
196 |
self.model_name = model_name
|
197 |
self.model_loaded = False
|
198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
try:
|
200 |
print(f"Loading model: {model_name}")
|
201 |
|
@@ -258,6 +266,9 @@ class SuggestionGenerator:
|
|
258 |
# For non-gated models, use the standard pipeline
|
259 |
self.generator = pipeline("text-generation", model=model_name)
|
260 |
|
|
|
|
|
|
|
261 |
self.model_loaded = True
|
262 |
print(f"Model loaded successfully: {model_name}")
|
263 |
return True
|
|
|
161 |
self.model_loaded = False
|
162 |
self.generator = None
|
163 |
self.aac_user_info = None
|
164 |
+
self.loaded_models = {} # Cache for loaded models
|
165 |
|
166 |
# Load AAC user information from social graph
|
167 |
try:
|
|
|
197 |
self.model_name = model_name
|
198 |
self.model_loaded = False
|
199 |
|
200 |
+
# Check if model is already loaded in cache
|
201 |
+
if model_name in self.loaded_models:
|
202 |
+
print(f"Using cached model: {model_name}")
|
203 |
+
self.generator = self.loaded_models[model_name]
|
204 |
+
self.model_loaded = True
|
205 |
+
return True
|
206 |
+
|
207 |
try:
|
208 |
print(f"Loading model: {model_name}")
|
209 |
|
|
|
266 |
# For non-gated models, use the standard pipeline
|
267 |
self.generator = pipeline("text-generation", model=model_name)
|
268 |
|
269 |
+
# Cache the loaded model
|
270 |
+
self.loaded_models[model_name] = self.generator
|
271 |
+
|
272 |
self.model_loaded = True
|
273 |
print(f"Model loaded successfully: {model_name}")
|
274 |
return True
|