{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "aaed9cbc", "metadata": {}, "outputs": [], "source": [ "import task\n", "import deit\n", "import trocr_models\n", "import torch\n", "import fairseq\n", "from fairseq import utils\n", "from fairseq_cli import generate\n", "from PIL import Image\n", "import torchvision.transforms as transforms\n", "\n", "\n", "def init(model_path, beam=5):\n", " model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(\n", " [model_path],\n", " arg_overrides={\"beam\": beam, \"task\": \"text_recognition\", \"data\": \"\", \"fp16\": False})\n", "\n", " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " model[0].to(device)\n", "\n", " img_transform = transforms.Compose([\n", " transforms.Resize((384, 384), interpolation=3),\n", " transforms.ToTensor(),\n", " transforms.Normalize(0.5, 0.5)\n", " ])\n", "\n", " generator = task.build_generator(\n", " model, cfg.generation, extra_gen_cls_kwargs={'lm_model': None, 'lm_weight': None}\n", " )\n", "\n", " bpe = task.build_bpe(cfg.bpe)\n", "\n", " return model, cfg, task, generator, bpe, img_transform, device\n", "\n", "\n", "def preprocess(img_path, img_transform):\n", " im = Image.open(img_path).convert('RGB').resize((384, 384))\n", " im = img_transform(im).unsqueeze(0).to(device).float()\n", "\n", " sample = {\n", " 'net_input': {\"imgs\": im},\n", " }\n", "\n", " return sample\n", "\n", "\n", "def get_text(cfg, generator, model, sample, bpe):\n", " decoder_output = task.inference_step(generator, model, sample, prefix_tokens=None, constraints=None)\n", " decoder_output = decoder_output[0][0] #top1\n", "\n", " hypo_tokens, hypo_str, alignment = utils.post_process_prediction(\n", " hypo_tokens=decoder_output[\"tokens\"].int().cpu(),\n", " src_str=\"\",\n", " alignment=decoder_output[\"alignment\"],\n", " align_dict=None,\n", " tgt_dict=model[0].decoder.dictionary,\n", " remove_bpe=cfg.common_eval.post_process,\n", " extra_symbols_to_ignore=generate.get_symbols_to_strip_from_output(generator),\n", " )\n", "\n", " detok_hypo_str = bpe.decode(hypo_str)\n", "\n", " return detok_hypo_str" ] }, { "cell_type": "code", "execution_count": null, "id": "b95c01e4", "metadata": {}, "outputs": [], "source": [ "model_path = 'path/to/model'\n", "jpg_path = \"path/to/pic\"\n", "beam = 5\n", "\n", "model, cfg, task, generator, bpe, img_transform, device = init(model_path, beam)\n", "\n", "sample = preprocess(jpg_path, img_transform)\n", "\n", "text = get_text(cfg, generator, model, sample, bpe)\n", "\n", "print(text)\n", "\n", "print('done')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.5 ('base')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "vscode": { "interpreter": { "hash": "0b8488e5f98ef3932f4ff0893213e55e6ba8b00dde307078d0f3efb25017ce11" } } }, "nbformat": 4, "nbformat_minor": 5 }