Spaces:
Build error
Build error
Update model.py
Browse files
model.py
CHANGED
|
@@ -167,7 +167,7 @@ class MultiHeadClassification(nn.Module):
|
|
| 167 |
Returns:
|
| 168 |
None
|
| 169 |
"""
|
| 170 |
-
model = torch.load(path)
|
| 171 |
if head_name in self.heads:
|
| 172 |
num_classes = model['weight'].shape[0]
|
| 173 |
self.heads[head_name].load_state_dict(model)
|
|
@@ -209,7 +209,7 @@ class MultiHeadClassification(nn.Module):
|
|
| 209 |
Args:
|
| 210 |
path (str): Path to the file
|
| 211 |
"""
|
| 212 |
-
self.load_state_dict(torch.load(path))
|
| 213 |
self.to(self.torch_dtype).to(self.device)
|
| 214 |
|
| 215 |
def save_backbone(self, path):
|
|
|
|
| 167 |
Returns:
|
| 168 |
None
|
| 169 |
"""
|
| 170 |
+
model = torch.load(path, map_location=self.device)
|
| 171 |
if head_name in self.heads:
|
| 172 |
num_classes = model['weight'].shape[0]
|
| 173 |
self.heads[head_name].load_state_dict(model)
|
|
|
|
| 209 |
Args:
|
| 210 |
path (str): Path to the file
|
| 211 |
"""
|
| 212 |
+
self.load_state_dict(torch.load(path, map_location=self.device))
|
| 213 |
self.to(self.torch_dtype).to(self.device)
|
| 214 |
|
| 215 |
def save_backbone(self, path):
|