ddas committed on
Commit
040a4cc
·
unverified ·
1 Parent(s): 4669bb8
Files changed (1) hide show
  1. instruction_classifier.py +25 -7
instruction_classifier.py CHANGED
@@ -76,8 +76,11 @@ class InstructionClassifierSanitizer:
76
  self.model_repo_id = model_repo_id
77
  self.model_filename = model_filename
78
 
79
- # Initialize device
80
- self.device = get_device()
 
 
 
81
 
82
  # Map friendly names to actual model names
83
  model_mapping = {
@@ -103,9 +106,10 @@ class InstructionClassifierSanitizer:
103
  model_path = "models/best_instruction_classifier.pth"
104
 
105
  if os.path.exists(model_path):
106
- checkpoint = torch.load(model_path, map_location=self.device)
107
  self._load_model_weights(checkpoint)
108
  print(f"βœ… Loaded instruction classifier model from {model_path}")
 
109
  else:
110
  raise FileNotFoundError(f"Model file not found: {model_path}")
111
  else:
@@ -141,14 +145,15 @@ class InstructionClassifierSanitizer:
141
  file_size = os.path.getsize(model_path) / (1024**3) # GB
142
  print(f" File size: {file_size:.2f} GB")
143
 
144
- # Load the checkpoint from the downloaded file
145
  print("πŸ”„ Loading checkpoint into memory...")
146
- checkpoint = torch.load(model_path, map_location=self.device)
147
  print(f" Checkpoint keys: {len(checkpoint.keys())}")
148
 
149
  self._load_model_weights(checkpoint)
150
  print(f"βœ… Model weights loaded from {self.model_repo_id}")
151
  print(f" Model parameter count: {sum(p.numel() for p in self.model.parameters())}")
 
152
 
153
  except Exception as e:
154
  print(f"❌ CRITICAL ERROR: Failed to download model from {self.model_repo_id}")
@@ -171,9 +176,9 @@ class InstructionClassifierSanitizer:
171
  if not key.startswith('loss_fct'): # Skip loss function weights
172
  model_state_dict[key] = value
173
 
174
- # Load the filtered state dict
175
  self.model.load_state_dict(model_state_dict, strict=False)
176
- self.model.to(self.device)
177
  self.model.eval()
178
 
179
  @spaces.GPU
@@ -190,6 +195,12 @@ class InstructionClassifierSanitizer:
190
  if not tool_output or not tool_output.strip():
191
  return tool_output
192
 
 
 
 
 
 
 
193
  try:
194
  # Step 1: Detect if the tool output contains instructions
195
  is_injection, confidence_score, tagged_text = self._detect_injection(tool_output)
@@ -219,6 +230,7 @@ class InstructionClassifierSanitizer:
219
  # Return original output if sanitization fails
220
  return tool_output
221
 
 
222
  def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]]]:
223
  """
224
  Sanitization function that also returns annotation data for flagged content.
@@ -233,6 +245,12 @@ class InstructionClassifierSanitizer:
233
  if not tool_output or not tool_output.strip():
234
  return tool_output, []
235
 
 
 
 
 
 
 
236
  try:
237
  # Step 1: Detect if the tool output contains instructions
238
  is_injection, confidence_score, tagged_text = self._detect_injection(tool_output)
 
76
  self.model_repo_id = model_repo_id
77
  self.model_filename = model_filename
78
 
79
+ # Initialize device - always use CPU for initialization in ZeroGPU environments
80
+ # GPU operations will be handled within @spaces.GPU decorated methods
81
+ self.device = torch.device('cpu')
82
+ self.target_device = get_device() # Store target device for later use
83
+ print(f"πŸ”§ Device configuration: init_device={self.device}, target_device={self.target_device}")
84
 
85
  # Map friendly names to actual model names
86
  model_mapping = {
 
106
  model_path = "models/best_instruction_classifier.pth"
107
 
108
  if os.path.exists(model_path):
109
+ checkpoint = torch.load(model_path, map_location='cpu')
110
  self._load_model_weights(checkpoint)
111
  print(f"βœ… Loaded instruction classifier model from {model_path}")
112
+ print(f" Model loaded on {self.device} for ZeroGPU compatibility")
113
  else:
114
  raise FileNotFoundError(f"Model file not found: {model_path}")
115
  else:
 
145
  file_size = os.path.getsize(model_path) / (1024**3) # GB
146
  print(f" File size: {file_size:.2f} GB")
147
 
148
+ # Load the checkpoint from the downloaded file - always use CPU for ZeroGPU compatibility
149
  print("πŸ”„ Loading checkpoint into memory...")
150
+ checkpoint = torch.load(model_path, map_location='cpu')
151
  print(f" Checkpoint keys: {len(checkpoint.keys())}")
152
 
153
  self._load_model_weights(checkpoint)
154
  print(f"βœ… Model weights loaded from {self.model_repo_id}")
155
  print(f" Model parameter count: {sum(p.numel() for p in self.model.parameters())}")
156
+ print(f" Model loaded on {self.device} for ZeroGPU compatibility")
157
 
158
  except Exception as e:
159
  print(f"❌ CRITICAL ERROR: Failed to download model from {self.model_repo_id}")
 
176
  if not key.startswith('loss_fct'): # Skip loss function weights
177
  model_state_dict[key] = value
178
 
179
+ # Load the filtered state dict - keep on CPU for ZeroGPU compatibility
180
  self.model.load_state_dict(model_state_dict, strict=False)
181
+ self.model.to(self.device) # Keep on CPU during initialization
182
  self.model.eval()
183
 
184
  @spaces.GPU
 
195
  if not tool_output or not tool_output.strip():
196
  return tool_output
197
 
198
+ # Move model to target device (GPU) within @spaces.GPU decorated method
199
+ if self.device != self.target_device:
200
+ print(f"πŸš€ Moving model from {self.device} to {self.target_device} within @spaces.GPU context")
201
+ self.model.to(self.target_device)
202
+ self.device = self.target_device
203
+
204
  try:
205
  # Step 1: Detect if the tool output contains instructions
206
  is_injection, confidence_score, tagged_text = self._detect_injection(tool_output)
 
230
  # Return original output if sanitization fails
231
  return tool_output
232
 
233
+ @spaces.GPU
234
  def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]]]:
235
  """
236
  Sanitization function that also returns annotation data for flagged content.
 
245
  if not tool_output or not tool_output.strip():
246
  return tool_output, []
247
 
248
+ # Move model to target device (GPU) within @spaces.GPU decorated method
249
+ if self.device != self.target_device:
250
+ print(f"πŸš€ Moving model from {self.device} to {self.target_device} within @spaces.GPU context")
251
+ self.model.to(self.target_device)
252
+ self.device = self.target_device
253
+
254
  try:
255
  # Step 1: Detect if the tool output contains instructions
256
  is_injection, confidence_score, tagged_text = self._detect_injection(tool_output)