Spaces:
Sleeping
Sleeping
import argparse

import torch
import fairseq
from fairseq import utils
from fairseq_cli import generate
from PIL import Image
import torchvision.transforms as transforms

# Project-local modules. `deit` and `trocr_models` are not referenced by name
# in this script; presumably they are imported so their fairseq
# registrations run. Note that `task` is later rebound to the loaded fairseq
# task object in the __main__ block.
import task
import deit
import trocr_models
def init(model_path, beam=5):
    """Load a TrOCR checkpoint and assemble the objects needed for inference.

    Args:
        model_path: Filesystem path of the fairseq checkpoint to load.
        beam: Beam size, forwarded to fairseq through ``arg_overrides``.

    Returns:
        Tuple ``(model, cfg, task, generator, bpe, img_transform, device)``:

        - ``model`` is the list returned by
          ``fairseq.checkpoint_utils.load_model_ensemble_and_task``; only its
          first entry is moved to ``device``.
        - ``generator`` and ``bpe`` are built by the loaded task.
        - ``img_transform`` resizes to 384x384, converts to a tensor and
          normalizes with mean/std 0.5.
        - ``device`` is ``"cuda"`` when available, else ``"cpu"``.
    """
    overrides = {"beam": beam, "task": "text_recognition", "data": "", "fp16": False}
    models, config, loaded_task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [model_path],
        arg_overrides=overrides,
    )

    target = "cpu"
    if torch.cuda.is_available():
        target = "cuda"
    models[0].to(target)

    # interpolation=3 is PIL's integer code for bicubic resampling.
    pipeline = [
        transforms.Resize((384, 384), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
    img_transform = transforms.Compose(pipeline)

    # Explicitly disable language-model fusion in the generator.
    gen_kwargs = {'lm_model': None, 'lm_weight': None}
    generator = loaded_task.build_generator(
        models, config.generation, extra_gen_cls_kwargs=gen_kwargs
    )
    bpe = loaded_task.build_bpe(config.bpe)

    return models, config, loaded_task, generator, bpe, img_transform, target
def preprocess(img_path, img_transform, device=None):
    """Load an image from disk and wrap it as a fairseq inference sample.

    Args:
        img_path: Path of the image file to read.
        img_transform: Callable applied to the RGB PIL image. Its result must
            support ``.unsqueeze(0)``, e.g. the ``transforms.Compose``
            pipeline returned by ``init``.
        device: Device the image tensor is moved to. Defaults to ``"cuda"``
            when CUDA is available, otherwise ``"cpu"`` (the same rule
            ``init`` uses), so the function no longer depends on a
            module-global ``device`` that only exists in script mode.

    Returns:
        ``{"net_input": {"imgs": tensor}}``, where ``tensor`` is the
        transformed image with a leading batch dimension of 1, cast to float.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # The context manager releases the file handle. convert() returns a new
    # image, so ``im`` stays usable after the file is closed.
    with Image.open(img_path) as raw:
        im = raw.convert('RGB').resize((384, 384))

    im = img_transform(im).unsqueeze(0).to(device).float()
    return {'net_input': {"imgs": im}}
def get_text(cfg, generator, model, sample, bpe, task=None):
    """Decode ``sample`` and return the top-1 hypothesis as plain text.

    Args:
        cfg: fairseq config from ``init``. ``cfg.common_eval.post_process``
            is forwarded as ``remove_bpe``.
        generator: Sequence generator built by the task's ``build_generator``.
        model: List of loaded models. ``model[0].decoder.dictionary`` maps
            token ids back to symbols.
        sample: Dict of the form produced by ``preprocess``.
        bpe: BPE object from the task's ``build_bpe``. When ``None``, the
            post-processed string is returned without BPE decoding instead
            of crashing.
        task: Optional fairseq task whose ``inference_step`` runs decoding.
            When omitted, ``generator.generate`` is called directly under
            ``torch.no_grad()``. Previously this function read a module-global
            ``task``, which is the imported ``task`` module (not a fairseq
            task) when this file is imported rather than run as a script.

    Returns:
        The decoded text of the highest-ranked hypothesis.
    """
    if task is not None:
        hypos = task.inference_step(
            generator, model, sample, prefix_tokens=None, constraints=None
        )
    else:
        # Mirrors fairseq's default FairseqTask.inference_step.
        # NOTE(review): assumes the text_recognition task does not override
        # inference_step; pass ``task`` explicitly if it does.
        with torch.no_grad():
            hypos = generator.generate(
                model, sample, prefix_tokens=None, constraints=None
            )

    best = hypos[0][0]  # top-1 hypothesis of the first (batch-size-1) sample
    _, hypo_str, _ = utils.post_process_prediction(
        hypo_tokens=best["tokens"].int().cpu(),
        src_str="",
        alignment=best["alignment"],
        align_dict=None,
        tgt_dict=model[0].decoder.dictionary,
        remove_bpe=cfg.common_eval.post_process,
        extra_symbols_to_ignore=generate.get_symbols_to_strip_from_output(generator),
    )
    return bpe.decode(hypo_str) if bpe is not None else hypo_str
if __name__ == '__main__':
    # Command-line options. The defaults are the previous hard-coded values,
    # so running without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Run TrOCR text recognition on a single image."
    )
    parser.add_argument("--model-path", default='path/to/model',
                        help="fairseq checkpoint to load")
    parser.add_argument("--image", default="path/to/pic",
                        help="image file to recognise")
    parser.add_argument("--beam", type=int, default=5,
                        help="beam size for decoding")
    args = parser.parse_args()

    model_path = args.model_path
    jpg_path = args.image
    beam = args.beam

    # Kept at module scope (not inside a main() function) so that
    # module-level names such as `device` and `task` remain available to
    # helpers that look them up globally.
    model, cfg, task, generator, bpe, img_transform, device = init(model_path, beam)
    sample = preprocess(jpg_path, img_transform)
    text = get_text(cfg, generator, model, sample, bpe)
    print(text)
    print('done')