Spaces:
Running
on
Zero
Running
on
Zero
cuda fix
Browse files- instruction_classifier.py +25 -7
instruction_classifier.py
CHANGED
@@ -76,8 +76,11 @@ class InstructionClassifierSanitizer:
|
|
76 |
self.model_repo_id = model_repo_id
|
77 |
self.model_filename = model_filename
|
78 |
|
79 |
-
# Initialize device
|
80 |
-
|
|
|
|
|
|
|
81 |
|
82 |
# Map friendly names to actual model names
|
83 |
model_mapping = {
|
@@ -103,9 +106,10 @@ class InstructionClassifierSanitizer:
|
|
103 |
model_path = "models/best_instruction_classifier.pth"
|
104 |
|
105 |
if os.path.exists(model_path):
|
106 |
-
checkpoint = torch.load(model_path, map_location=
|
107 |
self._load_model_weights(checkpoint)
|
108 |
print(f"✅ Loaded instruction classifier model from {model_path}")
|
|
|
109 |
else:
|
110 |
raise FileNotFoundError(f"Model file not found: {model_path}")
|
111 |
else:
|
@@ -141,14 +145,15 @@ class InstructionClassifierSanitizer:
|
|
141 |
file_size = os.path.getsize(model_path) / (1024**3) # GB
|
142 |
print(f" File size: {file_size:.2f} GB")
|
143 |
|
144 |
-
# Load the checkpoint from the downloaded file
|
145 |
print("🔄 Loading checkpoint into memory...")
|
146 |
-
checkpoint = torch.load(model_path, map_location=
|
147 |
print(f" Checkpoint keys: {len(checkpoint.keys())}")
|
148 |
|
149 |
self._load_model_weights(checkpoint)
|
150 |
print(f"✅ Model weights loaded from {self.model_repo_id}")
|
151 |
print(f" Model parameter count: {sum(p.numel() for p in self.model.parameters())}")
|
|
|
152 |
|
153 |
except Exception as e:
|
154 |
print(f"❌ CRITICAL ERROR: Failed to download model from {self.model_repo_id}")
|
@@ -171,9 +176,9 @@ class InstructionClassifierSanitizer:
|
|
171 |
if not key.startswith('loss_fct'): # Skip loss function weights
|
172 |
model_state_dict[key] = value
|
173 |
|
174 |
-
# Load the filtered state dict
|
175 |
self.model.load_state_dict(model_state_dict, strict=False)
|
176 |
-
self.model.to(self.device)
|
177 |
self.model.eval()
|
178 |
|
179 |
@spaces.GPU
|
@@ -190,6 +195,12 @@ class InstructionClassifierSanitizer:
|
|
190 |
if not tool_output or not tool_output.strip():
|
191 |
return tool_output
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
try:
|
194 |
# Step 1: Detect if the tool output contains instructions
|
195 |
is_injection, confidence_score, tagged_text = self._detect_injection(tool_output)
|
@@ -219,6 +230,7 @@ class InstructionClassifierSanitizer:
|
|
219 |
# Return original output if sanitization fails
|
220 |
return tool_output
|
221 |
|
|
|
222 |
def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]]]:
|
223 |
"""
|
224 |
Sanitization function that also returns annotation data for flagged content.
|
@@ -233,6 +245,12 @@ class InstructionClassifierSanitizer:
|
|
233 |
if not tool_output or not tool_output.strip():
|
234 |
return tool_output, []
|
235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
try:
|
237 |
# Step 1: Detect if the tool output contains instructions
|
238 |
is_injection, confidence_score, tagged_text = self._detect_injection(tool_output)
|
|
|
76 |
self.model_repo_id = model_repo_id
|
77 |
self.model_filename = model_filename
|
78 |
|
79 |
+
# Initialize device - always use CPU for initialization in ZeroGPU environments
|
80 |
+
# GPU operations will be handled within @spaces.GPU decorated methods
|
81 |
+
self.device = torch.device('cpu')
|
82 |
+
self.target_device = get_device() # Store target device for later use
|
83 |
+
print(f"🔧 Device configuration: init_device={self.device}, target_device={self.target_device}")
|
84 |
|
85 |
# Map friendly names to actual model names
|
86 |
model_mapping = {
|
|
|
106 |
model_path = "models/best_instruction_classifier.pth"
|
107 |
|
108 |
if os.path.exists(model_path):
|
109 |
+
checkpoint = torch.load(model_path, map_location='cpu')
|
110 |
self._load_model_weights(checkpoint)
|
111 |
print(f"✅ Loaded instruction classifier model from {model_path}")
|
112 |
+
print(f" Model loaded on {self.device} for ZeroGPU compatibility")
|
113 |
else:
|
114 |
raise FileNotFoundError(f"Model file not found: {model_path}")
|
115 |
else:
|
|
|
145 |
file_size = os.path.getsize(model_path) / (1024**3) # GB
|
146 |
print(f" File size: {file_size:.2f} GB")
|
147 |
|
148 |
+
# Load the checkpoint from the downloaded file - always use CPU for ZeroGPU compatibility
|
149 |
print("🔄 Loading checkpoint into memory...")
|
150 |
+
checkpoint = torch.load(model_path, map_location='cpu')
|
151 |
print(f" Checkpoint keys: {len(checkpoint.keys())}")
|
152 |
|
153 |
self._load_model_weights(checkpoint)
|
154 |
print(f"✅ Model weights loaded from {self.model_repo_id}")
|
155 |
print(f" Model parameter count: {sum(p.numel() for p in self.model.parameters())}")
|
156 |
+
print(f" Model loaded on {self.device} for ZeroGPU compatibility")
|
157 |
|
158 |
except Exception as e:
|
159 |
print(f"❌ CRITICAL ERROR: Failed to download model from {self.model_repo_id}")
|
|
|
176 |
if not key.startswith('loss_fct'): # Skip loss function weights
|
177 |
model_state_dict[key] = value
|
178 |
|
179 |
+
# Load the filtered state dict - keep on CPU for ZeroGPU compatibility
|
180 |
self.model.load_state_dict(model_state_dict, strict=False)
|
181 |
+
self.model.to(self.device) # Keep on CPU during initialization
|
182 |
self.model.eval()
|
183 |
|
184 |
@spaces.GPU
|
|
|
195 |
if not tool_output or not tool_output.strip():
|
196 |
return tool_output
|
197 |
|
198 |
+
# Move model to target device (GPU) within @spaces.GPU decorated method
|
199 |
+
if self.device != self.target_device:
|
200 |
+
print(f"🔄 Moving model from {self.device} to {self.target_device} within @spaces.GPU context")
|
201 |
+
self.model.to(self.target_device)
|
202 |
+
self.device = self.target_device
|
203 |
+
|
204 |
try:
|
205 |
# Step 1: Detect if the tool output contains instructions
|
206 |
is_injection, confidence_score, tagged_text = self._detect_injection(tool_output)
|
|
|
230 |
# Return original output if sanitization fails
|
231 |
return tool_output
|
232 |
|
233 |
+
@spaces.GPU
|
234 |
def sanitize_with_annotations(self, tool_output: str) -> Tuple[str, List[Dict[str, any]]]:
|
235 |
"""
|
236 |
Sanitization function that also returns annotation data for flagged content.
|
|
|
245 |
if not tool_output or not tool_output.strip():
|
246 |
return tool_output, []
|
247 |
|
248 |
+
# Move model to target device (GPU) within @spaces.GPU decorated method
|
249 |
+
if self.device != self.target_device:
|
250 |
+
print(f"🔄 Moving model from {self.device} to {self.target_device} within @spaces.GPU context")
|
251 |
+
self.model.to(self.target_device)
|
252 |
+
self.device = self.target_device
|
253 |
+
|
254 |
try:
|
255 |
# Step 1: Detect if the tool output contains instructions
|
256 |
is_injection, confidence_score, tagged_text = self._detect_injection(tool_output)
|