# svhn / app.py
# admin
# try upd req
# 4b55707
# raw
# history blame
# 3.64 kB
import functools
import os
import random
import warnings

import torch
import modelscope
import huggingface_hub
import gradio as gr
from PIL import Image
from torchvision import transforms

from model import Model
# EN_US is True unless the LANG environment variable is exactly "zh_CN.UTF-8".
# It selects both the UI language (through _L) and the hub the model is fetched from.
EN_US = os.getenv("LANG") != "zh_CN.UTF-8"

# Fetch the "Genius-Society/svhn" repository into ./__pycache__; MODEL_DIR holds
# the local directory path returned by the download call. The Hugging Face Hub is
# used when EN_US is true, otherwise ModelScope, with the same repository id.
if EN_US:
    MODEL_DIR = huggingface_hub.snapshot_download(
        "Genius-Society/svhn",
        cache_dir="./__pycache__",
    )
else:
    MODEL_DIR = modelscope.snapshot_download(
        "Genius-Society/svhn",
        cache_dir="./__pycache__",
    )
# Chinese UI strings -> English equivalents, used when EN_US is true.
ZH2EN = {
    "上传图片": "Upload an image",
    "状态栏": "Status",
    "选择模型": "Select a model",
    "识别结果": "Recognition result",
    "门牌号识别": "Door Number Recognition",
}


def _L(zh_txt: str) -> str:
    """Localize a UI label written in Chinese.

    Args:
        zh_txt: The Chinese source text of the label.

    Returns:
        The English translation from ZH2EN when EN_US is true, otherwise
        ``zh_txt`` unchanged. A label with no entry in ZH2EN falls back to
        ``zh_txt`` instead of raising KeyError (the same convention gettext
        uses for untranslated messages), so a missing translation cannot
        crash UI construction.
    """
    if not EN_US:
        return zh_txt
    return ZH2EN.get(zh_txt, zh_txt)
@functools.lru_cache(maxsize=8)
def _load_model(checkpoint_file: str):
    """Construct a Model, restore weights from MODEL_DIR/checkpoint_file, and memoize it.

    Re-creating and restoring the model on every request is wasteful, so each
    checkpoint is loaded at most once per process (up to 8 distinct files are
    kept). A load that raises is not cached and is retried on the next call.
    """
    # NOTE(review): checkpoint_file comes from the UI dropdown and is joined into
    # a path unvalidated; confirm Gradio rejects values outside `choices`.
    model = Model()
    model.restore(f"{MODEL_DIR}/{checkpoint_file}")
    return model.eval()


@functools.lru_cache(maxsize=1)
def _get_transform():
    """Return the shared preprocessing pipeline applied to every input image.

    Resize to 64x64, center-crop to 54x54, convert to a tensor, then normalize
    each RGB channel with mean 0.5 and std 0.5.
    """
    return transforms.Compose(
        [
            transforms.Resize([64, 64]),
            transforms.CenterCrop([54, 54]),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )


def _decode_digits(length: int, digits) -> str:
    """Join the first ``length`` predicted digits into the result string.

    Slicing (rather than indexing) means a predicted length larger than the
    number of digit predictions yields all available digits instead of an
    IndexError; a length of 0 yields "".
    """
    return "".join(str(d) for d in digits[:length])


def infer(input_img: str, checkpoint_file: str):
    """Recognize the house number in an image using the selected checkpoint.

    Args:
        input_img: Path to the image file to recognize.
        checkpoint_file: File name of a checkpoint inside MODEL_DIR.

    Returns:
        A ``(status, outstr)`` pair. ``status`` is "Success", or the text of the
        exception raised during inference; ``outstr`` is the predicted digit
        string on success and None on failure.
    """
    status = "Success"
    outstr = None
    try:
        model = _load_model(checkpoint_file)
        # Open inside `with` so the file handle is released even on error.
        with Image.open(input_img) as img:
            image = _get_transform()(img.convert("RGB"))
        images = image.unsqueeze(dim=0)  # add a batch dimension of size 1
        with torch.no_grad():
            # First output scores the number's length; the rest score each
            # digit position in order.
            length_logits, *digit_logits = model(images)
        length = length_logits.argmax(dim=1).item()
        digits = [logits.argmax(dim=1).item() for logits in digit_logits]
        outstr = _decode_digits(length, digits)
    except Exception as e:  # UI boundary: report any failure in the status box
        status = f"{e}"
    return status, outstr
def get_files(dir_path=MODEL_DIR, ext=".pth"):
    """List the names of regular files in ``dir_path`` whose names end with ``ext``.

    Args:
        dir_path: Directory to scan; defaults to the downloaded model directory.
        ext: Required file-name suffix, e.g. ".pth" or ".png".

    Returns:
        Matching file names (not full paths), sorted lexicographically.
        os.listdir returns entries in arbitrary order; sorting keeps the
        default dropdown model (the first entry) and the example order stable
        across runs and machines. Directories whose names happen to end with
        ``ext`` are skipped.
    """
    return sorted(
        name
        for name in os.listdir(dir_path)
        if name.endswith(ext) and os.path.isfile(os.path.join(dir_path, name))
    )
if __name__ == "__main__":
    # Suppress all Python warnings for the lifetime of the app process.
    warnings.filterwarnings("ignore")

    models = get_files()  # checkpoint (.pth) file names found in MODEL_DIR
    example_dir = f"{MODEL_DIR}/examples"

    # Each Gradio example row is [image path, checkpoint name]; every example
    # image is paired with a checkpoint picked at random.
    samples = [
        [f"{example_dir}/{name}", models[random.randint(0, len(models) - 1)]]
        for name in get_files(example_dir, ".png")
    ]

    inputs = [
        gr.Image(label=_L("上传图片"), type="filepath"),
        gr.Dropdown(label=_L("选择模型"), choices=models, value=models[0]),
    ]
    outputs = [
        gr.Textbox(label=_L("状态栏"), show_copy_button=True),
        gr.Textbox(label=_L("识别结果"), show_copy_button=True),
    ]
    app = gr.Interface(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
        examples=samples,
        title=_L("门牌号识别"),
        flagging_mode="never",
        cache_examples=False,
    )
    app.launch(ssr_mode=False)