admin commited on
Commit
6e9a4e4
·
1 Parent(s): f723eb2

fix squeezenet load

Browse files
Files changed (2) hide show
  1. app.py +5 -4
  2. model.py +9 -3
app.py CHANGED
@@ -138,14 +138,15 @@ def infer(audio_path: str, log_name: str):
138
  weight_path=f"{MODEL_DIR}/{log_name}.pt",
139
  )
140
 
 
 
 
 
141
  except Exception as e:
142
  return f"{e}", None
143
 
144
- input_size = eval_net.get_input_size()
145
- embeded_input = embed(input, input_size)
146
- output = list(eval_net.forward(embeded_input))
147
- outputs = []
148
  index = 0
 
149
  for y in output:
150
  preds = list(y.T)
151
  for pred in preds:
 
138
  weight_path=f"{MODEL_DIR}/{log_name}.pt",
139
  )
140
 
141
+ input_size = eval_net.get_input_size()
142
+ embeded_input = embed(input, input_size)
143
+ output = list(eval_net.forward(embeded_input))
144
+
145
  except Exception as e:
146
  return f"{e}", None
147
 
 
 
 
 
148
  index = 0
149
+ outputs = []
150
  for y in output:
151
  preds = list(y.T)
152
  for pred in preds:
model.py CHANGED
@@ -56,9 +56,15 @@ class EvalNet:
56
  torch.load(weight_path)
57
  if torch.cuda.is_available()
58
  else torch.load(weight_path, map_location="cpu")
59
- ) # self.model.load_state_dict(checkpoint, False)
60
- self.model.load_state_dict(checkpoint["model"], False)
61
- self.classifier.load_state_dict(checkpoint["classifier"], False)
 
 
 
 
 
 
62
  if torch.cuda.is_available():
63
  self.model = self.model.cuda()
64
  self.classifier = self.classifier.cuda()
 
56
  torch.load(weight_path)
57
  if torch.cuda.is_available()
58
  else torch.load(weight_path, map_location="cpu")
59
+ )
60
+
61
+ if self.type == "squeezenet":
62
+ self.model.load_state_dict(checkpoint, False)
63
+
64
+ else:
65
+ self.model.load_state_dict(checkpoint["model"], False)
66
+ self.classifier.load_state_dict(checkpoint["classifier"], False)
67
+
68
  if torch.cuda.is_available():
69
  self.model = self.model.cuda()
70
  self.classifier = self.classifier.cuda()