ddas committed on
Commit
3daf4c6
·
unverified ·
1 Parent(s): 1fde2f1

Add detailed logging for model loading debug

Browse files
Files changed (1) hide show
  1. instruction_classifier.py +30 -3
instruction_classifier.py CHANGED
@@ -114,14 +114,21 @@ class InstructionClassifierSanitizer:
114
  if hf_hub_download is None:
115
  raise ImportError("huggingface_hub is not installed")
116
 
 
 
 
 
117
  # Use HF_TOKEN from environment for private repositories
118
  token = os.getenv('HF_TOKEN')
119
  if token:
120
  print(f"πŸ“₯ Downloading private model from {self.model_repo_id}...")
 
121
  else:
122
  print(f"πŸ“₯ Downloading public model from {self.model_repo_id}...")
 
123
 
124
  # Download the model file (returns file path, not model object)
 
125
  model_path = hf_hub_download(
126
  repo_id=self.model_repo_id,
127
  filename=self.model_filename,
@@ -130,15 +137,30 @@ class InstructionClassifierSanitizer:
130
  )
131
  print(f"βœ… Model file downloaded to: {model_path}")
132
 
 
 
 
 
133
  # Load the checkpoint from the downloaded file
 
134
  checkpoint = torch.load(model_path, map_location=self.device)
 
 
135
  self._load_model_weights(checkpoint)
136
  print(f"βœ… Model weights loaded from {self.model_repo_id}")
 
 
137
  except Exception as e:
138
- print(f"❌ Failed to download model from {self.model_repo_id}: {e}")
139
- print("Full error details:")
 
 
140
  import traceback
141
  traceback.print_exc()
 
 
 
 
142
  raise RuntimeError(f"Failed to download model from {self.model_repo_id}: {e}")
143
 
144
  def _load_model_weights(self, checkpoint):
@@ -461,9 +483,14 @@ def sanitize_tool_output(tool_output):
461
  Returns:
462
  Sanitized tool output with instruction content removed
463
  """
 
 
464
  sanitizer = get_sanitizer()
465
  if sanitizer is None:
466
  print("⚠️ Instruction classifier not available, returning original output")
467
  return tool_output
468
 
469
- return sanitizer.sanitize_tool_output(tool_output)
 
 
 
 
114
  if hf_hub_download is None:
115
  raise ImportError("huggingface_hub is not installed")
116
 
117
+ print(f"πŸš€ Starting model download from {self.model_repo_id}")
118
+ print(f" Device: {self.device}")
119
+ print(f" Model name: {self.model_name}")
120
+
121
  # Use HF_TOKEN from environment for private repositories
122
  token = os.getenv('HF_TOKEN')
123
  if token:
124
  print(f"πŸ“₯ Downloading private model from {self.model_repo_id}...")
125
+ print(f" Using HF_TOKEN: {token[:8]}...{token[-8:] if len(token) > 16 else 'short'}")
126
  else:
127
  print(f"πŸ“₯ Downloading public model from {self.model_repo_id}...")
128
+ print(" No HF_TOKEN found - using public access")
129
 
130
  # Download the model file (returns file path, not model object)
131
+ print(f" Downloading {self.model_filename}...")
132
  model_path = hf_hub_download(
133
  repo_id=self.model_repo_id,
134
  filename=self.model_filename,
 
137
  )
138
  print(f"βœ… Model file downloaded to: {model_path}")
139
 
140
+ # Check file size
141
+ file_size = os.path.getsize(model_path) / (1024**3) # GB
142
+ print(f" File size: {file_size:.2f} GB")
143
+
144
  # Load the checkpoint from the downloaded file
145
+ print("πŸ”„ Loading checkpoint into memory...")
146
  checkpoint = torch.load(model_path, map_location=self.device)
147
+ print(f" Checkpoint keys: {len(checkpoint.keys())}")
148
+
149
  self._load_model_weights(checkpoint)
150
  print(f"βœ… Model weights loaded from {self.model_repo_id}")
151
+ print(f" Model parameter count: {sum(p.numel() for p in self.model.parameters())}")
152
+
153
  except Exception as e:
154
+ print(f"❌ CRITICAL ERROR: Failed to download model from {self.model_repo_id}")
155
+ print(f" Error type: {type(e).__name__}")
156
+ print(f" Error message: {e}")
157
+ print(" Full error details:")
158
  import traceback
159
  traceback.print_exc()
160
+ print(" Environment info:")
161
+ print(f" HF_TOKEN set: {'Yes' if os.getenv('HF_TOKEN') else 'No'}")
162
+ print(f" Device: {self.device}")
163
+ print(f" PyTorch version: {torch.__version__}")
164
  raise RuntimeError(f"Failed to download model from {self.model_repo_id}: {e}")
165
 
166
  def _load_model_weights(self, checkpoint):
 
483
  Returns:
484
  Sanitized tool output with instruction content removed
485
  """
486
+ print(f"πŸ” sanitize_tool_output called with: {tool_output[:100]}...")
487
+
488
  sanitizer = get_sanitizer()
489
  if sanitizer is None:
490
  print("⚠️ Instruction classifier not available, returning original output")
491
  return tool_output
492
 
493
+ print("βœ… Sanitizer found, processing...")
494
+ result = sanitizer.sanitize_tool_output(tool_output)
495
+ print(f"πŸ”’ Sanitization complete, result: {result[:100]}...")
496
+ return result