Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -28,25 +28,29 @@ def download_model():
|
|
| 28 |
# Load the model from Hugging Face
|
| 29 |
|
| 30 |
def load_model(model_path):
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
except Exception as e:
|
| 51 |
logging.error(f"Failed to load model: {e}")
|
| 52 |
raise
|
|
|
|
| 28 |
# Load the model from Hugging Face
|
| 29 |
|
| 30 |
def load_model(model_path):
|
| 31 |
+
model = models.resnet50(pretrained=False)
|
| 32 |
+
num_features = model.fc.in_features
|
| 33 |
+
model.fc = nn.Sequential(
|
| 34 |
+
nn.Dropout(0.5),
|
| 35 |
+
nn.Linear(num_features, 3) # 3 classes
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Load the checkpoint
|
| 39 |
+
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
|
| 40 |
+
|
| 41 |
+
# Adjust for state dict mismatch by renaming keys
|
| 42 |
+
state_dict = checkpoint['model_state_dict']
|
| 43 |
+
new_state_dict = {}
|
| 44 |
+
for k, v in state_dict.items():
|
| 45 |
+
if k == "fc.weight" or k == "fc.bias":
|
| 46 |
+
new_state_dict[f"fc.1.{k.split('.')[-1]}"] = v
|
| 47 |
+
else:
|
| 48 |
+
new_state_dict[k] = v
|
| 49 |
+
|
| 50 |
+
model.load_state_dict(new_state_dict, strict=False)
|
| 51 |
+
model.eval()
|
| 52 |
+
return model
|
| 53 |
+
|
| 54 |
except Exception as e:
|
| 55 |
logging.error(f"Failed to load model: {e}")
|
| 56 |
raise
|