admin commited on
Commit
a31338b
·
1 Parent(s): acd11f9
Files changed (2) hide show
  1. app.py +1 -1
  2. model.py +11 -2
app.py CHANGED
@@ -38,7 +38,7 @@ def _L(zh_txt: str):
38
 
39
  def infer(input_img: str, checkpoint_file: str):
40
  status = "Success"
41
- outstr = None
42
  try:
43
  model = Model()
44
  model.restore(f"{MODEL_DIR}/{checkpoint_file}")
 
38
 
39
  def infer(input_img: str, checkpoint_file: str):
40
  status = "Success"
41
+ outstr = ""
42
  try:
43
  model = Model()
44
  model.restore(f"{MODEL_DIR}/{checkpoint_file}")
model.py CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
7
 
8
  class Model(torch.jit.ScriptModule):
9
  CHECKPOINT_FILENAME_PATTERN = "model-{}.pth"
 
10
  __constants__ = [
11
  "_hidden1",
12
  "_hidden2",
@@ -30,6 +31,7 @@ class Model(torch.jit.ScriptModule):
30
 
31
  def __init__(self):
32
  super(Model, self).__init__()
 
33
  self._hidden1 = nn.Sequential(
34
  nn.Conv2d(in_channels=3, out_channels=48, kernel_size=5, padding=2),
35
  nn.BatchNorm2d(num_features=48),
@@ -88,6 +90,7 @@ class Model(torch.jit.ScriptModule):
88
  )
89
  self._hidden9 = nn.Sequential(nn.Linear(192 * 7 * 7, 3072), nn.ReLU())
90
  self._hidden10 = nn.Sequential(nn.Linear(3072, 3072), nn.ReLU())
 
91
  self._digit_length = nn.Sequential(nn.Linear(3072, 7))
92
  self._digit1 = nn.Sequential(nn.Linear(3072, 11))
93
  self._digit2 = nn.Sequential(nn.Linear(3072, 11))
@@ -108,12 +111,14 @@ class Model(torch.jit.ScriptModule):
108
  x = x.view(x.size(0), 192 * 7 * 7)
109
  x = self._hidden9(x)
110
  x = self._hidden10(x)
 
111
  length_logits = self._digit_length(x)
112
  digit1_logits = self._digit1(x)
113
  digit2_logits = self._digit2(x)
114
  digit3_logits = self._digit3(x)
115
  digit4_logits = self._digit4(x)
116
  digit5_logits = self._digit5(x)
 
117
  return (
118
  length_logits,
119
  digit1_logits,
@@ -147,6 +152,10 @@ class Model(torch.jit.ScriptModule):
147
 
148
  def restore(self, path_to_checkpoint_file):
149
  self.load_state_dict(
150
- torch.load(path_to_checkpoint_file, map_location=torch.device("cpu"))
 
 
 
151
  )
152
- return int(path_to_checkpoint_file.split("model-")[-1][:-4])
 
 
7
 
8
  class Model(torch.jit.ScriptModule):
9
  CHECKPOINT_FILENAME_PATTERN = "model-{}.pth"
10
+
11
  __constants__ = [
12
  "_hidden1",
13
  "_hidden2",
 
31
 
32
  def __init__(self):
33
  super(Model, self).__init__()
34
+
35
  self._hidden1 = nn.Sequential(
36
  nn.Conv2d(in_channels=3, out_channels=48, kernel_size=5, padding=2),
37
  nn.BatchNorm2d(num_features=48),
 
90
  )
91
  self._hidden9 = nn.Sequential(nn.Linear(192 * 7 * 7, 3072), nn.ReLU())
92
  self._hidden10 = nn.Sequential(nn.Linear(3072, 3072), nn.ReLU())
93
+
94
  self._digit_length = nn.Sequential(nn.Linear(3072, 7))
95
  self._digit1 = nn.Sequential(nn.Linear(3072, 11))
96
  self._digit2 = nn.Sequential(nn.Linear(3072, 11))
 
111
  x = x.view(x.size(0), 192 * 7 * 7)
112
  x = self._hidden9(x)
113
  x = self._hidden10(x)
114
+
115
  length_logits = self._digit_length(x)
116
  digit1_logits = self._digit1(x)
117
  digit2_logits = self._digit2(x)
118
  digit3_logits = self._digit3(x)
119
  digit4_logits = self._digit4(x)
120
  digit5_logits = self._digit5(x)
121
+
122
  return (
123
  length_logits,
124
  digit1_logits,
 
152
 
153
  def restore(self, path_to_checkpoint_file):
154
  self.load_state_dict(
155
+ torch.load(
156
+ path_to_checkpoint_file,
157
+ map_location=torch.device("cpu"),
158
+ )
159
  )
160
+ step = int(path_to_checkpoint_file.split("model-")[-1][:-4])
161
+ return step