# Uddipan Basu Bir
# Download checkpoint from HF hub in OcrReorderPipeline
# a01cae7
# raw
# history blame
# 4.04 kB
import os
import json
import base64
from io import BytesIO
from PIL import Image
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import (
AutoProcessor,
LayoutLMv3Model,
T5ForConditionalGeneration,
AutoTokenizer
)
# ── 1) MODEL SETUP ─────────────────────────────────────────────────────
# Hub repository that every pretrained component below is loaded from.
repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"

# Processor: built-in OCR is switched off, so words and boxes must be
# passed in explicitly when the processor is called.
processor = AutoProcessor.from_pretrained(repo, subfolder="preprocessor", apply_ocr=False)

# Encoder: loaded, placed on the CPU and switched to inference mode.
layout_model = LayoutLMv3Model.from_pretrained(repo)
layout_model.to("cpu")
layout_model.eval()

# Decoder: same treatment as the encoder.
t5_model = T5ForConditionalGeneration.from_pretrained(repo)
t5_model.to("cpu")
t5_model.eval()

# Tokenizer lives in the same subfolder as the processor.
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="preprocessor")
# Ensure decoder_start_token_id and bos_token_id are set
if t5_model.config.decoder_start_token_id is None:
    # Compare against None explicitly: token id 0 is a valid id, and a plain
    # `bos or eos` would wrongly skip it because 0 is falsy.
    fallback = tokenizer.bos_token_id
    if fallback is None:
        fallback = tokenizer.eos_token_id
    t5_model.config.decoder_start_token_id = fallback
if t5_model.config.bos_token_id is None:
    # Reuse whatever start token ended up configured above.
    t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
# Projection head
ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
# NOTE(security): torch.load unpickles the downloaded file, which can execute
# arbitrary code; only point `repo` at a trusted repository.
# TODO(review): consider `weights_only=True` once it is confirmed that the
# checkpoint holds nothing but tensors and plain containers.
ckpt = torch.load(ckpt_file, map_location="cpu")
# The projection weights are stored under their own key in the checkpoint.
proj_state = ckpt["projection"]
# Maps 768-wide encoder features to the decoder's embedding width (d_model).
projection = torch.nn.Sequential(
    torch.nn.Linear(768, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU(),
).to("cpu")
projection.load_state_dict(proj_state)
# Inference only: keep the head in eval mode like the encoder and decoder.
projection.eval()
# ── 2) INFERENCE FUNCTION ─────────────────────────────────────────────
def _find_entry(ndjson_path, img_name):
    """Return the first NDJSON record whose ``img_name`` equals *img_name*.

    The file is read line by line and scanning stops at the first match, so
    the whole file is never held in memory. Blank lines and records that are
    not JSON objects are skipped. Returns ``None`` when nothing matches.

    Raises:
        json.JSONDecodeError: if a non-blank line read before the match is
            not valid JSON.
    """
    with open(ndjson_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            if isinstance(record, dict) and record.get("img_name") == img_name:
                return record
    return None


def infer(image_path, json_file):
    """Run the reorder pipeline for one uploaded image and its NDJSON file.

    The NDJSON record whose ``img_name`` equals the image's base file name
    supplies ``src_word_list`` and ``src_wordbox_list``. Those are encoded
    together with the image by the LayoutLMv3 model, projected to the T5
    embedding width, and decoded with ``t5_model.generate``.

    Args:
        image_path: Path of the image file on disk.
        json_file: Either a path string or an object whose ``.name``
            attribute is the path of the NDJSON file (which of the two
            Gradio passes depends on its version and settings).

    Returns:
        The decoded text, or a message starting with "❌" when an input is
        missing or unusable.
    """
    if not image_path:
        return "❌ Please upload an image"
    if json_file is None:
        return "❌ Please upload an NDJSON file"

    img_name = os.path.basename(image_path)
    # Accept both a plain path string and a file-like wrapper exposing .name.
    json_path = getattr(json_file, "name", json_file)

    try:
        entry = _find_entry(json_path, img_name)
    except json.JSONDecodeError as err:
        return f"❌ Invalid NDJSON file: {err}"
    if entry is None:
        return f"❌ No JSON entry found for image '{img_name}'"

    words = entry.get("src_word_list") or []
    boxes = entry.get("src_wordbox_list") or []
    # Validate before touching the models so the user gets a clear message.
    if not words:
        return f"❌ JSON entry for image '{img_name}' has no words to reorder"
    if len(words) != len(boxes):
        return (
            f"❌ JSON entry for image '{img_name}' has {len(words)} words "
            f"but {len(boxes)} boxes"
        )

    # Preprocess image + tokens; the context manager closes the file handle.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")
    encoding = processor(
        [img], [words], boxes=[boxes],
        return_tensors="pt", padding=True, truncation=True
    )
    pixel_values = encoding.pixel_values.to("cpu")
    input_ids = encoding.input_ids.to("cpu")
    attention_mask = encoding.attention_mask.to("cpu")
    bbox = encoding.bbox.to("cpu")

    # Forward pass
    with torch.no_grad():
        # LayoutLMv3 encoding
        lm_out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox
        )
        # Keep only the first seq_len positions so the features line up with
        # attention_mask, which has one entry per input token.
        seq_len = input_ids.size(1)
        text_feats = lm_out.last_hidden_state[:, :seq_len, :]

        # Projection + T5 decoding
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id,
            bos_token_id=t5_model.config.bos_token_id
        )

    # Decode and return the single sequence in the batch
    result = tokenizer.batch_decode(
        gen_ids, skip_special_tokens=True
    )[0]
    return result
# ── 3) GRADIO INTERFACE ────────────────────────────────────────────────
# Input widgets, in the same order as infer's parameters: the image is
# delivered as a file path, followed by the NDJSON upload.
_image_input = gr.Image(type="filepath", label="Upload Image")
_ndjson_input = gr.File(label="Upload JSON (NDJSON)")

# Single-function UI: two uploads in, plain text out.
demo = gr.Interface(
    fn=infer,
    inputs=[_image_input, _ndjson_input],
    outputs="text",
    title="OCR Reorder Pipeline",
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(share=True)