jays009 commited on
Commit
aff9d06
·
verified ·
1 Parent(s): e97dbab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -23,19 +23,24 @@ def download_model():
23
  # Load the model from Hugging Face
24
  def load_model(model_path):
25
  model = models.resnet50(pretrained=False)
26
- model.fc = nn.Linear(model.fc.in_features, 3) # Adjust for 3 classes
 
 
 
 
27
  checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
28
-
29
- # Check if it is a checkpoint and extract model state dict
30
- if 'model_state_dict' in checkpoint:
31
- state_dict = checkpoint['model_state_dict']
32
- else:
33
- state_dict = checkpoint
34
-
35
- model.load_state_dict(state_dict)
36
  model.eval()
37
  return model
38
 
 
39
  # Path to your model
40
  model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
41
  model = load_model(model_path)
 
23
  # Load the model from Hugging Face
24
  def load_model(model_path):
25
  model = models.resnet50(pretrained=False)
26
+ num_features = model.fc.in_features
27
+ model.fc = nn.Sequential(
28
+ nn.Dropout(0.5),
29
+ nn.Linear(num_features, 3) # 3 classes
30
+ )
31
  checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
32
+ model.load_state_dict(checkpoint['model_state_dict'])
33
+
34
+ # Rename keys to match the model definition
35
+ state_dict['fc.weight'] = state_dict.pop('fc.1.weight')
36
+ state_dict['fc.bias'] = state_dict.pop('fc.1.bias')
37
+
38
+ # Load the modified state dict
39
+ model.load_state_dict(state_dict)
40
  model.eval()
41
  return model
42
 
43
+
44
  # Path to your model
45
  model_path = hf_hub_download(repo_id="jays009/Resnet3", filename="model.pth")
46
  model = load_model(model_path)