Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -28,25 +28,29 @@ def download_model():
|
|
28 |
# Load the model from Hugging Face
|
29 |
|
30 |
def load_model(model_path):
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
50 |
except Exception as e:
|
51 |
logging.error(f"Failed to load model: {e}")
|
52 |
raise
|
|
|
28 |
# Load the model from Hugging Face
|
29 |
|
30 |
def load_model(model_path):
|
31 |
+
model = models.resnet50(pretrained=False)
|
32 |
+
num_features = model.fc.in_features
|
33 |
+
model.fc = nn.Sequential(
|
34 |
+
nn.Dropout(0.5),
|
35 |
+
nn.Linear(num_features, 3) # 3 classes
|
36 |
+
)
|
37 |
+
|
38 |
+
# Load the checkpoint
|
39 |
+
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
|
40 |
+
|
41 |
+
# Adjust for state dict mismatch by renaming keys
|
42 |
+
state_dict = checkpoint['model_state_dict']
|
43 |
+
new_state_dict = {}
|
44 |
+
for k, v in state_dict.items():
|
45 |
+
if k == "fc.weight" or k == "fc.bias":
|
46 |
+
new_state_dict[f"fc.1.{k.split('.')[-1]}"] = v
|
47 |
+
else:
|
48 |
+
new_state_dict[k] = v
|
49 |
+
|
50 |
+
model.load_state_dict(new_state_dict, strict=False)
|
51 |
+
model.eval()
|
52 |
+
return model
|
53 |
+
|
54 |
except Exception as e:
|
55 |
logging.error(f"Failed to load model: {e}")
|
56 |
raise
|