Spaces:
Running
Running
admin
commited on
Commit
·
6e9a4e4
1
Parent(s):
f723eb2
fix squeezenet load
Browse files
app.py
CHANGED
@@ -138,14 +138,15 @@ def infer(audio_path: str, log_name: str):
|
|
138 |
weight_path=f"{MODEL_DIR}/{log_name}.pt",
|
139 |
)
|
140 |
|
|
|
|
|
|
|
|
|
141 |
except Exception as e:
|
142 |
return f"{e}", None
|
143 |
|
144 |
-
input_size = eval_net.get_input_size()
|
145 |
-
embeded_input = embed(input, input_size)
|
146 |
-
output = list(eval_net.forward(embeded_input))
|
147 |
-
outputs = []
|
148 |
index = 0
|
|
|
149 |
for y in output:
|
150 |
preds = list(y.T)
|
151 |
for pred in preds:
|
|
|
138 |
weight_path=f"{MODEL_DIR}/{log_name}.pt",
|
139 |
)
|
140 |
|
141 |
+
input_size = eval_net.get_input_size()
|
142 |
+
embeded_input = embed(input, input_size)
|
143 |
+
output = list(eval_net.forward(embeded_input))
|
144 |
+
|
145 |
except Exception as e:
|
146 |
return f"{e}", None
|
147 |
|
|
|
|
|
|
|
|
|
148 |
index = 0
|
149 |
+
outputs = []
|
150 |
for y in output:
|
151 |
preds = list(y.T)
|
152 |
for pred in preds:
|
model.py
CHANGED
@@ -56,9 +56,15 @@ class EvalNet:
|
|
56 |
torch.load(weight_path)
|
57 |
if torch.cuda.is_available()
|
58 |
else torch.load(weight_path, map_location="cpu")
|
59 |
-
)
|
60 |
-
|
61 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
if torch.cuda.is_available():
|
63 |
self.model = self.model.cuda()
|
64 |
self.classifier = self.classifier.cuda()
|
|
|
56 |
torch.load(weight_path)
|
57 |
if torch.cuda.is_available()
|
58 |
else torch.load(weight_path, map_location="cpu")
|
59 |
+
)
|
60 |
+
|
61 |
+
if self.type == "squeezenet":
|
62 |
+
self.model.load_state_dict(checkpoint, False)
|
63 |
+
|
64 |
+
else:
|
65 |
+
self.model.load_state_dict(checkpoint["model"], False)
|
66 |
+
self.classifier.load_state_dict(checkpoint["classifier"], False)
|
67 |
+
|
68 |
if torch.cuda.is_available():
|
69 |
self.model = self.model.cuda()
|
70 |
self.classifier = self.classifier.cuda()
|