import os
import random
import warnings

import gradio as gr
import huggingface_hub
import modelscope
import torch
from PIL import Image
from torchvision import transforms

from model import Model

# The UI is shown in English unless the process locale is Simplified Chinese.
EN_US = os.getenv("LANG") != "zh_CN.UTF-8"

# Local snapshot of the "Genius-Society/svhn" repository. It is fetched through
# huggingface_hub for the English UI and through modelscope otherwise.
MODEL_DIR = (
    huggingface_hub.snapshot_download(
        "Genius-Society/svhn",
        cache_dir="./__pycache__",
    )
    if EN_US
    else modelscope.snapshot_download(
        "Genius-Society/svhn",
        cache_dir="./__pycache__",
    )
)

# Chinese UI label -> English UI label.
ZH2EN = {
    "上传图片": "Upload an image",
    "状态栏": "Status",
    "选择模型": "Select a model",
    "识别结果": "Recognition result",
    "门牌号识别": "Door Number Recognition",
}


def _L(zh_txt: str):
    """Return the label for *zh_txt* in the active UI language.

    Raises:
        KeyError: if the English UI is active and *zh_txt* has no entry in
            ``ZH2EN``.
    """
    return ZH2EN[zh_txt] if EN_US else zh_txt


def infer(input_img: str, checkpoint_file: str):
    """Recognize the digit sequence shown in an image.

    Args:
        input_img: Path of the image file to recognize. Falsy (e.g. ``None``)
            when no image has been provided.
        checkpoint_file: File name of the checkpoint inside ``MODEL_DIR`` to
            restore the model from.

    Returns:
        A ``(status, outstr)`` tuple. ``status`` is ``"Success"`` or the
        message of the error that occurred; ``outstr`` is the recognized
        digits as a string (empty on failure).
    """
    if not input_img:
        # Report a readable message instead of the opaque error Image.open
        # would produce for a missing path.
        return "Please upload an image first", ""

    status = "Success"
    outstr = ""
    try:
        model = Model()
        model.restore(f"{MODEL_DIR}/{checkpoint_file}")

        transform = transforms.Compose(
            [
                transforms.Resize([64, 64]),
                transforms.CenterCrop([54, 54]),
                transforms.ToTensor(),
                # Rescale each channel from [0, 1] to [-1, 1].
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        # The context manager releases the underlying file handle as soon as
        # the pixels have been converted.
        with Image.open(input_img) as img:
            image = transform(img.convert("RGB"))

        # Add the batch dimension expected by the model.
        images = image.unsqueeze(dim=0)

        with torch.no_grad():
            # The first output scores the sequence length; the remaining
            # outputs score one digit position each.
            length_logits, *digit_logits = model.eval()(images)

        length = length_logits.max(1)[1].item()
        digits = [logits.max(1)[1].item() for logits in digit_logits]

        # Slicing clamps the predicted length to the number of digit heads.
        # NOTE(review): the length head presumably has a "more digits than
        # heads" class; indexing by it directly raised IndexError - confirm
        # against Model.
        outstr = "".join(str(digit) for digit in digits[:length])

    except Exception as e:
        # UI boundary: surface the failure in the status box rather than
        # crashing the request.
        status = f"{e}"

    return status, outstr


def get_files(dir_path=MODEL_DIR, ext=".pth"):
    """Return the sorted names of the entries in *dir_path* ending with *ext*.

    Sorting keeps the order stable between runs, since ``os.listdir`` returns
    entries in arbitrary order.
    """
    return sorted(name for name in os.listdir(dir_path) if name.endswith(ext))


if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    models = get_files()
    if not models:
        # Fail with a clear message instead of an IndexError further down.
        raise FileNotFoundError(f"No .pth checkpoint found in {MODEL_DIR}")

    images = get_files(f"{MODEL_DIR}/examples", ".png")
    # Pair every example image with a randomly chosen checkpoint.
    samples = [
        [f"{MODEL_DIR}/examples/{img}", random.choice(models)]
        for img in images
    ]

    gr.Interface(
        fn=infer,
        inputs=[
            gr.Image(label=_L("上传图片"), type="filepath"),
            gr.Dropdown(
                label=_L("选择模型"),
                choices=models,
                value=models[0],
            ),
        ],
        outputs=[
            gr.Textbox(label=_L("状态栏"), show_copy_button=True),
            gr.Textbox(label=_L("识别结果"), show_copy_button=True),
        ],
        examples=samples,
        title=_L("门牌号识别"),
        flagging_mode="never",
        cache_examples=False,
    ).launch(ssr_mode=False)