Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -23,12 +23,24 @@ def download_model():
|
|
23 |
# Load the model from Hugging Face
|
24 |
def load_model(model_path):
|
25 |
model = models.resnet50(pretrained=False)
|
26 |
-
model.fc = nn.Linear(model.fc.in_features,
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
model.eval()
|
29 |
-
logging.info("Model loaded successfully. Ready for inference.")
|
30 |
return model
|
31 |
|
|
|
|
|
|
|
|
|
|
|
32 |
# Download the model and load it
|
33 |
model_path = download_model()
|
34 |
model = load_model(model_path)
|
|
|
23 |
# Load the model from Hugging Face
|
24 |
def load_model(model_path):
|
25 |
model = models.resnet50(pretrained=False)
|
26 |
+
model.fc = nn.Linear(model.fc.in_features, 3) # Adjust for 3 classes
|
27 |
+
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
|
28 |
+
|
29 |
+
# Check if it is a checkpoint and extract model state dict
|
30 |
+
if 'model_state_dict' in checkpoint:
|
31 |
+
state_dict = checkpoint['model_state_dict']
|
32 |
+
else:
|
33 |
+
state_dict = checkpoint
|
34 |
+
|
35 |
+
model.load_state_dict(state_dict)
|
36 |
model.eval()
|
|
|
37 |
return model
|
38 |
|
39 |
+
# Path to your model
|
40 |
+
model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
|
41 |
+
model = load_model(model_path)
|
42 |
+
|
43 |
+
|
44 |
# Download the model and load it
|
45 |
model_path = download_model()
|
46 |
model = load_model(model_path)
|