jays009 commited on
Commit
811ddcb
·
verified ·
1 Parent(s): 5a8efa9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -19
app.py CHANGED
@@ -28,25 +28,29 @@ def download_model():
28
  # Load the model from Hugging Face
29
 
30
  def load_model(model_path):
31
- try:
32
- model = models.resnet50(pretrained=False)
33
- num_features = model.fc.in_features
34
- model.fc = nn.Sequential(
35
- nn.Dropout(0.5),
36
- nn.Linear(num_features, num_classes)
37
- )
38
- checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
39
-
40
- # Ensure compatibility by handling key mismatches
41
- model_state_dict = checkpoint['model_state_dict']
42
- for key in list(model_state_dict.keys()):
43
- if key.startswith('fc.1'):
44
- model_state_dict[key.replace('fc.1', 'fc')] = model_state_dict.pop(key)
45
-
46
- model.load_state_dict(model_state_dict)
47
- model.eval()
48
- logging.info("Model loaded successfully.")
49
- return model
 
 
 
 
50
  except Exception as e:
51
  logging.error(f"Failed to load model: {e}")
52
  raise
 
28
  # Load the model from Hugging Face
29
 
30
  def load_model(model_path):
31
+ model = models.resnet50(pretrained=False)
32
+ num_features = model.fc.in_features
33
+ model.fc = nn.Sequential(
34
+ nn.Dropout(0.5),
35
+ nn.Linear(num_features, 3) # 3 classes
36
+ )
37
+
38
+ # Load the checkpoint
39
+ checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
40
+
41
+ # Adjust for state dict mismatch by renaming keys
42
+ state_dict = checkpoint['model_state_dict']
43
+ new_state_dict = {}
44
+ for k, v in state_dict.items():
45
+ if k == "fc.weight" or k == "fc.bias":
46
+ new_state_dict[f"fc.1.{k.split('.')[-1]}"] = v
47
+ else:
48
+ new_state_dict[k] = v
49
+
50
+ model.load_state_dict(new_state_dict, strict=False)
51
+ model.eval()
52
+ return model
53
+
54
  except Exception as e:
55
  logging.error(f"Failed to load model: {e}")
56
  raise