jays009 commited on
Commit
e97dbab
·
verified ·
1 Parent(s): 9c2da5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -23,12 +23,24 @@ def download_model():
23
  # Load the model from Hugging Face
24
  def load_model(model_path):
25
  model = models.resnet50(pretrained=False)
26
- model.fc = nn.Linear(model.fc.in_features, num_classes)
27
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
 
 
 
 
 
 
 
 
28
  model.eval()
29
- logging.info("Model loaded successfully. Ready for inference.")
30
  return model
31
 
 
 
 
 
 
32
  # Download the model and load it
33
  model_path = download_model()
34
  model = load_model(model_path)
 
23
  # Load the model from Hugging Face
24
  def load_model(model_path):
25
  model = models.resnet50(pretrained=False)
26
+ model.fc = nn.Linear(model.fc.in_features, 3) # Adjust for 3 classes
27
+ checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
28
+
29
+ # Check if it is a checkpoint and extract model state dict
30
+ if 'model_state_dict' in checkpoint:
31
+ state_dict = checkpoint['model_state_dict']
32
+ else:
33
+ state_dict = checkpoint
34
+
35
+ model.load_state_dict(state_dict)
36
  model.eval()
 
37
  return model
38
 
39
+ # Path to your model
40
+ model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
41
+ model = load_model(model_path)
42
+
43
+
44
  # Download the model and load it
45
  model_path = download_model()
46
  model = load_model(model_path)