sync ms
admin committed · Commit a31338b
Parent(s): acd11f9
app.py
CHANGED
@@ -38,7 +38,7 @@ def _L(zh_txt: str):
 
 def infer(input_img: str, checkpoint_file: str):
     status = "Success"
-    outstr =
+    outstr = ""
     try:
         model = Model()
         model.restore(f"{MODEL_DIR}/{checkpoint_file}")
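The app.py change initializes outstr to an empty string before the try block, so the variable is always defined even if restoring the model raises an exception. A minimal sketch of the pattern, assuming infer reports the failure through status and still returns outstr (the except/return handling below is illustrative, not taken from the diff):

def infer(input_img: str, checkpoint_file: str):
    status = "Success"
    outstr = ""  # defined up front so a failure below cannot leave it undefined later
    try:
        model = Model()
        model.restore(f"{MODEL_DIR}/{checkpoint_file}")
        # ... run the model and build outstr ...
    except Exception as exc:  # illustrative error handling, not part of the commit
        status = f"Error: {exc}"
    return outstr, status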
model.py
CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
 
 class Model(torch.jit.ScriptModule):
     CHECKPOINT_FILENAME_PATTERN = "model-{}.pth"
+
     __constants__ = [
         "_hidden1",
         "_hidden2",
@@ -30,6 +31,7 @@ class Model(torch.jit.ScriptModule):
 
     def __init__(self):
         super(Model, self).__init__()
+
         self._hidden1 = nn.Sequential(
             nn.Conv2d(in_channels=3, out_channels=48, kernel_size=5, padding=2),
             nn.BatchNorm2d(num_features=48),
@@ -88,6 +90,7 @@ class Model(torch.jit.ScriptModule):
         )
         self._hidden9 = nn.Sequential(nn.Linear(192 * 7 * 7, 3072), nn.ReLU())
         self._hidden10 = nn.Sequential(nn.Linear(3072, 3072), nn.ReLU())
+
         self._digit_length = nn.Sequential(nn.Linear(3072, 7))
         self._digit1 = nn.Sequential(nn.Linear(3072, 11))
         self._digit2 = nn.Sequential(nn.Linear(3072, 11))
@@ -108,12 +111,14 @@ class Model(torch.jit.ScriptModule):
         x = x.view(x.size(0), 192 * 7 * 7)
         x = self._hidden9(x)
         x = self._hidden10(x)
+
         length_logits = self._digit_length(x)
         digit1_logits = self._digit1(x)
         digit2_logits = self._digit2(x)
         digit3_logits = self._digit3(x)
         digit4_logits = self._digit4(x)
         digit5_logits = self._digit5(x)
+
         return (
             length_logits,
             digit1_logits,
@@ -147,6 +152,10 @@ class Model(torch.jit.ScriptModule):
 
     def restore(self, path_to_checkpoint_file):
         self.load_state_dict(
-            torch.load(
+            torch.load(
+                path_to_checkpoint_file,
+                map_location=torch.device("cpu"),
+            )
         )
-
+        step = int(path_to_checkpoint_file.split("model-")[-1][:-4])
+        return step
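The substantive model.py change is in restore(): torch.load now receives map_location=torch.device("cpu"), so a checkpoint saved on a GPU machine can be loaded on a CPU-only Space, and the method returns the training step parsed from the checkpoint filename (the existing "model-{}.pth" pattern). A minimal usage sketch, with a hypothetical checkpoint path and an assumed 1x3x54x54 input (the 192 * 7 * 7 flatten suggests 54x54 crops, but the input size is not shown in this diff):

import torch
from model import Model

model = Model()
step = model.restore("logs/model-54000.pth")  # hypothetical path matching "model-{}.pth"
print(step)  # 54000, parsed from the filename by the new restore()

model.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 54, 54)  # assumed input size; not stated in this diff
    length_logits, d1, d2, d3, d4, d5 = model(x)
    length = length_logits.argmax(dim=1)  # 7 length classes, i.e. 0-6 digits
    digits = [d.argmax(dim=1) for d in (d1, d2, d3, d4, d5)]  # 11 classes each; class 10 presumably "no digit"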