Spaces:
Running
Running
admin
commited on
Commit
·
a31338b
1
Parent(s):
acd11f9
sync ms
Browse files
app.py
CHANGED
|
@@ -38,7 +38,7 @@ def _L(zh_txt: str):
|
|
| 38 |
|
| 39 |
def infer(input_img: str, checkpoint_file: str):
|
| 40 |
status = "Success"
|
| 41 |
-
outstr =
|
| 42 |
try:
|
| 43 |
model = Model()
|
| 44 |
model.restore(f"{MODEL_DIR}/{checkpoint_file}")
|
|
|
|
| 38 |
|
| 39 |
def infer(input_img: str, checkpoint_file: str):
|
| 40 |
status = "Success"
|
| 41 |
+
outstr = ""
|
| 42 |
try:
|
| 43 |
model = Model()
|
| 44 |
model.restore(f"{MODEL_DIR}/{checkpoint_file}")
|
model.py
CHANGED
|
@@ -7,6 +7,7 @@ import torch.nn as nn
|
|
| 7 |
|
| 8 |
class Model(torch.jit.ScriptModule):
|
| 9 |
CHECKPOINT_FILENAME_PATTERN = "model-{}.pth"
|
|
|
|
| 10 |
__constants__ = [
|
| 11 |
"_hidden1",
|
| 12 |
"_hidden2",
|
|
@@ -30,6 +31,7 @@ class Model(torch.jit.ScriptModule):
|
|
| 30 |
|
| 31 |
def __init__(self):
|
| 32 |
super(Model, self).__init__()
|
|
|
|
| 33 |
self._hidden1 = nn.Sequential(
|
| 34 |
nn.Conv2d(in_channels=3, out_channels=48, kernel_size=5, padding=2),
|
| 35 |
nn.BatchNorm2d(num_features=48),
|
|
@@ -88,6 +90,7 @@ class Model(torch.jit.ScriptModule):
|
|
| 88 |
)
|
| 89 |
self._hidden9 = nn.Sequential(nn.Linear(192 * 7 * 7, 3072), nn.ReLU())
|
| 90 |
self._hidden10 = nn.Sequential(nn.Linear(3072, 3072), nn.ReLU())
|
|
|
|
| 91 |
self._digit_length = nn.Sequential(nn.Linear(3072, 7))
|
| 92 |
self._digit1 = nn.Sequential(nn.Linear(3072, 11))
|
| 93 |
self._digit2 = nn.Sequential(nn.Linear(3072, 11))
|
|
@@ -108,12 +111,14 @@ class Model(torch.jit.ScriptModule):
|
|
| 108 |
x = x.view(x.size(0), 192 * 7 * 7)
|
| 109 |
x = self._hidden9(x)
|
| 110 |
x = self._hidden10(x)
|
|
|
|
| 111 |
length_logits = self._digit_length(x)
|
| 112 |
digit1_logits = self._digit1(x)
|
| 113 |
digit2_logits = self._digit2(x)
|
| 114 |
digit3_logits = self._digit3(x)
|
| 115 |
digit4_logits = self._digit4(x)
|
| 116 |
digit5_logits = self._digit5(x)
|
|
|
|
| 117 |
return (
|
| 118 |
length_logits,
|
| 119 |
digit1_logits,
|
|
@@ -147,6 +152,10 @@ class Model(torch.jit.ScriptModule):
|
|
| 147 |
|
| 148 |
def restore(self, path_to_checkpoint_file):
|
| 149 |
self.load_state_dict(
|
| 150 |
-
torch.load(
|
|
|
|
|
|
|
|
|
|
| 151 |
)
|
| 152 |
-
|
|
|
|
|
|
| 7 |
|
| 8 |
class Model(torch.jit.ScriptModule):
|
| 9 |
CHECKPOINT_FILENAME_PATTERN = "model-{}.pth"
|
| 10 |
+
|
| 11 |
__constants__ = [
|
| 12 |
"_hidden1",
|
| 13 |
"_hidden2",
|
|
|
|
| 31 |
|
| 32 |
def __init__(self):
|
| 33 |
super(Model, self).__init__()
|
| 34 |
+
|
| 35 |
self._hidden1 = nn.Sequential(
|
| 36 |
nn.Conv2d(in_channels=3, out_channels=48, kernel_size=5, padding=2),
|
| 37 |
nn.BatchNorm2d(num_features=48),
|
|
|
|
| 90 |
)
|
| 91 |
self._hidden9 = nn.Sequential(nn.Linear(192 * 7 * 7, 3072), nn.ReLU())
|
| 92 |
self._hidden10 = nn.Sequential(nn.Linear(3072, 3072), nn.ReLU())
|
| 93 |
+
|
| 94 |
self._digit_length = nn.Sequential(nn.Linear(3072, 7))
|
| 95 |
self._digit1 = nn.Sequential(nn.Linear(3072, 11))
|
| 96 |
self._digit2 = nn.Sequential(nn.Linear(3072, 11))
|
|
|
|
| 111 |
x = x.view(x.size(0), 192 * 7 * 7)
|
| 112 |
x = self._hidden9(x)
|
| 113 |
x = self._hidden10(x)
|
| 114 |
+
|
| 115 |
length_logits = self._digit_length(x)
|
| 116 |
digit1_logits = self._digit1(x)
|
| 117 |
digit2_logits = self._digit2(x)
|
| 118 |
digit3_logits = self._digit3(x)
|
| 119 |
digit4_logits = self._digit4(x)
|
| 120 |
digit5_logits = self._digit5(x)
|
| 121 |
+
|
| 122 |
return (
|
| 123 |
length_logits,
|
| 124 |
digit1_logits,
|
|
|
|
| 152 |
|
| 153 |
def restore(self, path_to_checkpoint_file):
|
| 154 |
self.load_state_dict(
|
| 155 |
+
torch.load(
|
| 156 |
+
path_to_checkpoint_file,
|
| 157 |
+
map_location=torch.device("cpu"),
|
| 158 |
+
)
|
| 159 |
)
|
| 160 |
+
step = int(path_to_checkpoint_file.split("model-")[-1][:-4])
|
| 161 |
+
return step
|