Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -23,19 +23,24 @@ def download_model():
|
|
23 |
# Load the model from Hugging Face
|
24 |
def load_model(model_path):
|
25 |
model = models.resnet50(pretrained=False)
|
26 |
-
|
|
|
|
|
|
|
|
|
27 |
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
model.eval()
|
37 |
return model
|
38 |
|
|
|
39 |
# Path to your model
|
40 |
model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
|
41 |
model = load_model(model_path)
|
|
|
23 |
# Load the model from Hugging Face
|
24 |
def load_model(model_path):
|
25 |
model = models.resnet50(pretrained=False)
|
26 |
+
num_features = model.fc.in_features
|
27 |
+
model.fc = nn.Sequential(
|
28 |
+
nn.Dropout(0.5),
|
29 |
+
nn.Linear(num_features, 3) # 3 classes
|
30 |
+
)
|
31 |
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
|
32 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
33 |
+
|
34 |
+
# Rename keys to match the model definition
|
35 |
+
state_dict['fc.weight'] = state_dict.pop('fc.1.weight')
|
36 |
+
state_dict['fc.bias'] = state_dict.pop('fc.1.bias')
|
37 |
+
|
38 |
+
# Load the modified state dict
|
39 |
+
model.load_state_dict(state_dict)
|
40 |
model.eval()
|
41 |
return model
|
42 |
|
43 |
+
|
44 |
# Path to your model
|
45 |
model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
|
46 |
model = load_model(model_path)
|