willwade committed on
Commit
759b2cf
·
1 Parent(s): 8173918

Add progress indicators and model caching

Browse files
Files changed (2) hide show
  1. app.py +26 -2
  2. utils.py +11 -0
app.py CHANGED
@@ -103,11 +103,12 @@ def on_person_change(person_id):
103
  return context_info, phrases_text, topics
104
 
105
 
106
- def change_model(model_name):
107
  """Change the language model used for generation.
108
 
109
  Args:
110
  model_name: The name of the model to use
 
111
 
112
  Returns:
113
  A status message about the model change
@@ -120,12 +121,17 @@ def change_model(model_name):
120
  if model_name == suggestion_generator.model_name:
121
  return f"Already using model: {model_name}"
122
 
 
 
 
123
  # Try to load the new model
124
  success = suggestion_generator.load_model(model_name)
125
 
126
  if success:
 
127
  return f"Successfully switched to model: {model_name}"
128
  else:
 
129
  return f"Failed to load model: {model_name}. Using fallback responses instead."
130
 
131
 
@@ -136,6 +142,7 @@ def generate_suggestions(
136
  selected_topic=None,
137
  model_name="distilgpt2",
138
  temperature=0.7,
 
139
  ):
140
  """Generate suggestions based on the selected person and user input."""
141
  print(
@@ -144,13 +151,17 @@ def generate_suggestions(
144
  f"model={model_name}, temperature={temperature}"
145
  )
146
 
 
 
 
147
  if not person_id:
148
  print("No person_id provided")
149
  return "Please select who you're talking to first."
150
 
151
  # Make sure we're using the right model
152
  if model_name != suggestion_generator.model_name:
153
- change_model(model_name)
 
154
 
155
  person_context = social_graph.get_person_context(person_id)
156
  print(f"Person context: {person_context}")
@@ -206,9 +217,13 @@ def generate_suggestions(
206
  # If suggestion type is "model", use the language model for multiple suggestions
207
  if suggestion_type == "model":
208
  print("Using model for suggestions")
 
 
209
  # Generate 3 different suggestions
210
  suggestions = []
211
  for i in range(3):
 
 
212
  print(f"Generating suggestion {i+1}/3")
213
  try:
214
  suggestion = suggestion_generator.generate_suggestion(
@@ -247,9 +262,14 @@ def generate_suggestions(
247
  else:
248
  print("No category inferred, falling back to model")
249
  # Fall back to model if we couldn't infer a category
 
250
  try:
251
  suggestions = []
252
  for i in range(3):
 
 
 
 
253
  suggestion = suggestion_generator.generate_suggestion(
254
  person_context, user_input, temperature=temperature
255
  )
@@ -278,6 +298,10 @@ def generate_suggestions(
278
  result = "No suggestions available. Please try a different option."
279
 
280
  print(f"Returning result: {result[:100]}...")
 
 
 
 
281
  return result
282
 
283
 
 
103
  return context_info, phrases_text, topics
104
 
105
 
106
+ def change_model(model_name, progress=gr.Progress()):
107
  """Change the language model used for generation.
108
 
109
  Args:
110
  model_name: The name of the model to use
111
+ progress: Gradio progress indicator
112
 
113
  Returns:
114
  A status message about the model change
 
121
  if model_name == suggestion_generator.model_name:
122
  return f"Already using model: {model_name}"
123
 
124
+ # Show progress indicator
125
+ progress(0, desc=f"Loading model: {model_name}")
126
+
127
  # Try to load the new model
128
  success = suggestion_generator.load_model(model_name)
129
 
130
  if success:
131
+ progress(1.0, desc=f"Model loaded: {model_name}")
132
  return f"Successfully switched to model: {model_name}"
133
  else:
134
+ progress(1.0, desc="Model loading failed")
135
  return f"Failed to load model: {model_name}. Using fallback responses instead."
136
 
137
 
 
142
  selected_topic=None,
143
  model_name="distilgpt2",
144
  temperature=0.7,
145
+ progress=gr.Progress(),
146
  ):
147
  """Generate suggestions based on the selected person and user input."""
148
  print(
 
151
  f"model={model_name}, temperature={temperature}"
152
  )
153
 
154
+ # Initialize progress
155
+ progress(0, desc="Starting...")
156
+
157
  if not person_id:
158
  print("No person_id provided")
159
  return "Please select who you're talking to first."
160
 
161
  # Make sure we're using the right model
162
  if model_name != suggestion_generator.model_name:
163
+ progress(0.1, desc=f"Switching to model: {model_name}")
164
+ change_model(model_name, progress)
165
 
166
  person_context = social_graph.get_person_context(person_id)
167
  print(f"Person context: {person_context}")
 
217
  # If suggestion type is "model", use the language model for multiple suggestions
218
  if suggestion_type == "model":
219
  print("Using model for suggestions")
220
+ progress(0.2, desc="Preparing to generate suggestions...")
221
+
222
  # Generate 3 different suggestions
223
  suggestions = []
224
  for i in range(3):
225
+ progress_value = 0.3 + (i * 0.2) # Progress from 30% to 70%
226
+ progress(progress_value, desc=f"Generating suggestion {i+1}/3")
227
  print(f"Generating suggestion {i+1}/3")
228
  try:
229
  suggestion = suggestion_generator.generate_suggestion(
 
262
  else:
263
  print("No category inferred, falling back to model")
264
  # Fall back to model if we couldn't infer a category
265
+ progress(0.3, desc="No category detected, using model instead...")
266
  try:
267
  suggestions = []
268
  for i in range(3):
269
+ progress_value = 0.4 + (i * 0.15) # Progress from 40% to 70%
270
+ progress(
271
+ progress_value, desc=f"Generating fallback suggestion {i+1}/3"
272
+ )
273
  suggestion = suggestion_generator.generate_suggestion(
274
  person_context, user_input, temperature=temperature
275
  )
 
298
  result = "No suggestions available. Please try a different option."
299
 
300
  print(f"Returning result: {result[:100]}...")
301
+
302
+ # Complete the progress
303
+ progress(1.0, desc="Completed!")
304
+
305
  return result
306
 
307
 
utils.py CHANGED
@@ -161,6 +161,7 @@ class SuggestionGenerator:
161
  self.model_loaded = False
162
  self.generator = None
163
  self.aac_user_info = None
 
164
 
165
  # Load AAC user information from social graph
166
  try:
@@ -196,6 +197,13 @@ class SuggestionGenerator:
196
  self.model_name = model_name
197
  self.model_loaded = False
198
 
 
 
 
 
 
 
 
199
  try:
200
  print(f"Loading model: {model_name}")
201
 
@@ -258,6 +266,9 @@ class SuggestionGenerator:
258
  # For non-gated models, use the standard pipeline
259
  self.generator = pipeline("text-generation", model=model_name)
260
 
 
 
 
261
  self.model_loaded = True
262
  print(f"Model loaded successfully: {model_name}")
263
  return True
 
161
  self.model_loaded = False
162
  self.generator = None
163
  self.aac_user_info = None
164
+ self.loaded_models = {} # Cache for loaded models
165
 
166
  # Load AAC user information from social graph
167
  try:
 
197
  self.model_name = model_name
198
  self.model_loaded = False
199
 
200
+ # Check if model is already loaded in cache
201
+ if model_name in self.loaded_models:
202
+ print(f"Using cached model: {model_name}")
203
+ self.generator = self.loaded_models[model_name]
204
+ self.model_loaded = True
205
+ return True
206
+
207
  try:
208
  print(f"Loading model: {model_name}")
209
 
 
266
  # For non-gated models, use the standard pipeline
267
  self.generator = pipeline("text-generation", model=model_name)
268
 
269
+ # Cache the loaded model
270
+ self.loaded_models[model_name] = self.generator
271
+
272
  self.model_loaded = True
273
  print(f"Model loaded successfully: {model_name}")
274
  return True