File size: 3,641 Bytes
278c80b
 
 
 
4b55707
 
278c80b
 
 
 
 
4b55707
278c80b
4b55707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278c80b
 
 
4b55707
a31338b
278c80b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b55707
 
 
278c80b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b55707
278c80b
4b55707
278c80b
 
 
 
4b55707
 
 
 
278c80b
4b55707
b780a4b
278c80b
4b55707
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import torch
import random
import warnings
import modelscope
import huggingface_hub
import gradio as gr
from PIL import Image
from model import Model
from torchvision import transforms

# The UI is English unless the LANG environment variable is exactly
# "zh_CN.UTF-8"; an unset LANG therefore also selects English.
EN_US = os.environ.get("LANG") != "zh_CN.UTF-8"

# Download the "Genius-Society/svhn" snapshot when the module is loaded:
# via huggingface_hub for the English locale, via modelscope otherwise.
# Both branches use the same repo id and the same local cache directory.
if EN_US:
    MODEL_DIR = huggingface_hub.snapshot_download(
        "Genius-Society/svhn",
        cache_dir="./__pycache__",
    )
else:
    MODEL_DIR = modelscope.snapshot_download(
        "Genius-Society/svhn",
        cache_dir="./__pycache__",
    )

# Chinese UI label -> English UI label.
ZH2EN = {
    "上传图片": "Upload an image",
    "状态栏": "Status",
    "选择模型": "Select a model",
    "识别结果": "Recognition result",
    "门牌号识别": "Door Number Recognition",
}


def _L(zh_txt: str):
    """Return the UI label for *zh_txt* in the active interface language.

    When ``EN_US`` is true the label is looked up in ``ZH2EN`` (a key that
    is not in the table raises ``KeyError``); otherwise the Chinese text
    is returned unchanged.
    """
    if not EN_US:
        return zh_txt

    return ZH2EN[zh_txt]


def infer(input_img: str, checkpoint_file: str):
    """Recognize the digit sequence in an image with the chosen checkpoint.

    Args:
        input_img: Path of the image file to recognize.
        checkpoint_file: File name of a checkpoint inside ``MODEL_DIR``.

    Returns:
        A ``(status, outstr)`` tuple of strings. ``status`` is ``"Success"``
        when inference completed, otherwise the message of the exception
        that was raised. ``outstr`` is the recognized digit string and is
        empty when inference failed.
    """
    status = "Success"
    outstr = ""
    try:
        # A fresh model is built and restored on every call so that the
        # checkpoint selected for this request is the one that is used.
        model = Model()
        model.restore(f"{MODEL_DIR}/{checkpoint_file}")

        transform = transforms.Compose(
            [
                transforms.Resize([64, 64]),
                transforms.CenterCrop([54, 54]),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        # Open the file in a context manager so its handle is released as
        # soon as the pixels have been converted, instead of lingering
        # until garbage collection.
        with Image.open(input_img) as raw_image:
            image = transform(raw_image.convert("RGB"))

        # Add the leading batch dimension the model is called with.
        images = image.unsqueeze(dim=0)

        with torch.no_grad():
            # The first output holds the sequence-length logits; the
            # remaining outputs are the per-position digit logits.
            length_logits, *digit_logits = model.eval()(images)

        length = length_logits.max(1)[1].item()
        digits = [logits.max(1)[1].item() for logits in digit_logits]

        # Slicing clamps a predicted length that exceeds the number of
        # digit heads, where indexing would raise IndexError.
        outstr = "".join(str(digit) for digit in digits[:length])

    except Exception as e:
        # Any failure is reported through the status box rather than
        # propagated to the UI framework.
        status = f"{e}"

    return status, outstr


def get_files(dir_path=MODEL_DIR, ext=".pth"):
    """List the entries directly inside *dir_path* whose names end with *ext*.

    Only bare names are returned (no directory prefix), in the order
    ``os.listdir`` yields them. Entries are matched by name alone, so a
    sub-directory whose name ends with *ext* is included as well.
    """
    return [name for name in os.listdir(dir_path) if name.endswith(ext)]


if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    examples_dir = f"{MODEL_DIR}/examples"
    models = get_files()

    # Each example row pairs one example image with a randomly picked
    # checkpoint, matching the (image, model) inputs of the interface.
    samples = [
        [f"{examples_dir}/{img}", models[random.randint(0, len(models) - 1)]]
        for img in get_files(examples_dir, ".png")
    ]

    # Input widgets: the image is handed to infer() as a file path, and the
    # dropdown starts on the first checkpoint that was found.
    image_input = gr.Image(label=_L("上传图片"), type="filepath")
    model_input = gr.Dropdown(
        label=_L("选择模型"),
        choices=models,
        value=models[0],
    )

    # Output widgets, in the order infer() returns them: status, then result.
    status_output = gr.Textbox(label=_L("状态栏"), show_copy_button=True)
    result_output = gr.Textbox(label=_L("识别结果"), show_copy_button=True)

    demo = gr.Interface(
        fn=infer,
        inputs=[image_input, model_input],
        outputs=[status_output, result_output],
        examples=samples,
        title=_L("门牌号识别"),
        flagging_mode="never",
        cache_examples=False,
    )
    demo.launch(ssr_mode=False)