Luigi committed on
Commit d33dfcd · 1 Parent(s): eb215ff

Add model caching

Files changed (1)
  1. app.py +8 -4
app.py CHANGED
@@ -163,6 +163,13 @@ def validate_or_download_model(selected_model):
         st.stop()
     return result
 
+# ------------------------------
+# Caching the Model Loading
+# ------------------------------
+@st.cache_resource
+def load_cached_model(selected_model):
+    return validate_or_download_model(selected_model)
+
 def stream_response(llm, messages, max_tokens, temperature, top_k, top_p, repeat_penalty, response_queue):
     """Stream the model response token-by-token."""
     final_text = ""
@@ -229,10 +236,7 @@ with st.sidebar:
     selected_model = MODELS[selected_model_name]
     if st.session_state.model_name != selected_model_name:
         with st.spinner("Loading selected model..."):
-            if st.session_state.llm is not None:
-                del st.session_state.llm
-                gc.collect()
-            st.session_state.llm = validate_or_download_model(selected_model)
+            st.session_state.llm = load_cached_model(selected_model)
             st.session_state.model_name = selected_model_name
 
     llm = st.session_state.llm
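
Note on the pattern: `st.cache_resource` memoizes the decorated loader per distinct argument value, so a rerun that asks for an already-loaded model gets the existing object back instead of re-downloading and re-instantiating it; that is what makes the explicit `del` / `gc.collect()` bookkeeping removable here. A minimal sketch of the behaviour follows, using a stand-in loader (the `_FakeModel` class and the model name are illustrative only, not part of the app):

import streamlit as st

class _FakeModel:
    """Stand-in for the app's real model object."""
    def __init__(self, name):
        self.name = name

@st.cache_resource
def load_cached_model(selected_model):
    # Runs once per distinct `selected_model`; later calls with the same
    # value return the object created on the first call.
    return _FakeModel(selected_model)

# Inside a running Streamlit app, repeated reruns hit the cache:
first = load_cached_model("model-a.gguf")
again = load_cached_model("model-a.gguf")
# `first is again` holds: both names refer to the same cached instance.

# If a reload is ever needed (e.g. the file on disk changed), the cache
# for this function can be dropped explicitly:
load_cached_model.clear()

One trade-off worth noting: unlike the removed `del` / `gc.collect()` branch, `st.cache_resource` keeps every model it has loaded resident, so switching between several large models holds all of them in memory until the cache is cleared.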