Luigi commited on
Commit
6e8312c
·
1 Parent(s): cc91a1a

fault-free model loading

Browse files
Files changed (1) hide show
  1. app.py +29 -30
app.py CHANGED
@@ -63,22 +63,9 @@ if "model_name" not in st.session_state:
63
  if "llm" not in st.session_state:
64
  st.session_state.llm = None
65
 
66
- # Make sure models dir exists
67
  os.makedirs("models", exist_ok=True)
68
 
69
- # If the selected model file does not exist or is invalid, clean up and re-download
70
- def validate_or_download_model():
71
- if not os.path.exists(model_path):
72
- cleanup_old_models()
73
- download_model()
74
- return
75
- try:
76
- _ = Llama(model_path=model_path, n_ctx=16, n_threads=1) # dummy check
77
- except Exception as e:
78
- st.warning(f"Model file was invalid or corrupt: {e}\nRedownloading...")
79
- cleanup_old_models()
80
- download_model()
81
-
82
  def cleanup_old_models():
83
  for f in os.listdir("models"):
84
  if f.endswith(".gguf") and f != selected_model["filename"]:
@@ -96,28 +83,40 @@ def download_model():
96
  local_dir_use_symlinks=False,
97
  )
98
 
99
- validate_or_download_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  # Load model if changed
102
  if st.session_state.model_name != selected_model_name:
103
  if st.session_state.llm is not None:
104
  del st.session_state.llm
105
  gc.collect()
106
- try:
107
- st.session_state.llm = Llama(
108
- model_path=model_path,
109
- n_ctx=1024,
110
- n_threads=2,
111
- n_threads_batch=2,
112
- n_batch=4,
113
- n_gpu_layers=0,
114
- use_mlock=False,
115
- use_mmap=True,
116
- verbose=False,
117
- )
118
- except Exception as e:
119
- st.error(f"Failed to load model: {e}")
120
- st.stop()
121
  st.session_state.model_name = selected_model_name
122
 
123
  llm = st.session_state.llm
 
63
  if "llm" not in st.session_state:
64
  st.session_state.llm = None
65
 
66
+ # Ensure model directory exists
67
  os.makedirs("models", exist_ok=True)
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def cleanup_old_models():
70
  for f in os.listdir("models"):
71
  if f.endswith(".gguf") and f != selected_model["filename"]:
 
83
  local_dir_use_symlinks=False,
84
  )
85
 
86
+ def try_load_model(path):
87
+ try:
88
+ return Llama(model_path=path, n_ctx=1024, n_threads=2, n_threads_batch=2, n_batch=4, n_gpu_layers=0, use_mlock=False, use_mmap=True, verbose=False)
89
+ except Exception as e:
90
+ return str(e)
91
+
92
+ def validate_or_download_model():
93
+ if not os.path.exists(model_path):
94
+ cleanup_old_models()
95
+ download_model()
96
+
97
+ # First load attempt
98
+ result = try_load_model(model_path)
99
+ if isinstance(result, str):
100
+ st.warning(f"Initial load failed: {result}\nAttempting re-download...")
101
+ try:
102
+ os.remove(model_path)
103
+ except:
104
+ pass
105
+ cleanup_old_models()
106
+ download_model()
107
+ result = try_load_model(model_path)
108
+ if isinstance(result, str):
109
+ st.error(f"Model still failed after re-download: {result}")
110
+ st.stop()
111
+ return result
112
+ return result
113
 
114
  # Load model if changed
115
  if st.session_state.model_name != selected_model_name:
116
  if st.session_state.llm is not None:
117
  del st.session_state.llm
118
  gc.collect()
119
+ st.session_state.llm = validate_or_download_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  st.session_state.model_name = selected_model_name
121
 
122
  llm = st.session_state.llm