import os
import json
from PIL import Image
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import (
    AutoProcessor,
    LayoutLMv3Model,
    T5ForConditionalGeneration,
    AutoTokenizer
)

# ── 1) MODEL SETUP ─────────────────────────────────────────────────────
repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"

# Processor for LayoutLMv3 (OCR disabled: words and boxes come from the NDJSON)
processor = AutoProcessor.from_pretrained(
    repo, subfolder="preprocessor", apply_ocr=False
)

# LayoutLMv3 encoder
layout_model = LayoutLMv3Model.from_pretrained(repo)
layout_model.eval()

# T5 decoder & tokenizer
t5_model = T5ForConditionalGeneration.from_pretrained(repo)
t5_model.eval()
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="preprocessor")

# Ensure decoder_start_token_id is set
if t5_model.config.decoder_start_token_id is None:
    # Fall back to bos_token_id if present
    t5_model.config.decoder_start_token_id = tokenizer.bos_token_id

# Projection head: maps LayoutLMv3 hidden states (768-d) into T5's embedding space
ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
ckpt = torch.load(ckpt_file, map_location="cpu")
proj_state = ckpt["projection"]
projection = torch.nn.Sequential(
    torch.nn.Linear(768, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU()
)
projection.load_state_dict(proj_state)
projection.eval()

# Move models to CPU (Spaces are CPU-only)
device = torch.device("cpu")
layout_model.to(device)
t5_model.to(device)
projection.to(device)

# ── 2) INFERENCE FUNCTION ──────────────────────────────────────────────
def infer(image_path, json_file):
    img_name = os.path.basename(image_path)

    # 2.a) Load NDJSON file (one JSON object per line)
    data = []
    with open(json_file.name, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            data.append(json.loads(line))

    # 2.b) Find the entry matching the uploaded image
    entry = next((e for e in data if e.get("img_name") == img_name), None)
    if entry is None:
        return f"❌ No JSON entry found for image '{img_name}'"

    words = entry.get("src_word_list", [])
    boxes = entry.get("src_wordbox_list", [])
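
    # Note (assumption): with apply_ocr=False, the LayoutLMv3 processor expects
    # each word box already normalized to the 0-1000 coordinate space that
    # LayoutLMv3 uses. If src_wordbox_list held raw pixel coordinates instead,
    # a helper like this (hypothetical, unused here) would need to be applied to
    # every box before the processor call below:
    def normalize_box(box, width, height):
        x0, y0, x1, y1 = box
        return [
            int(1000 * x0 / width),
            int(1000 * y0 / height),
            int(1000 * x1 / width),
            int(1000 * y1 / height),
        ]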

    # 2.c) Preprocess the image together with its tokens and boxes
    img = Image.open(image_path).convert("RGB")
    encoding = processor(
        [img], [words], boxes=[boxes],
        return_tensors="pt", padding=True, truncation=True
    )
    pixel_values = encoding.pixel_values.to(device)
    input_ids = encoding.input_ids.to(device)
    attention_mask = encoding.attention_mask.to(device)
    bbox = encoding.bbox.to(device)

    # 2.d) Forward pass
    with torch.no_grad():
        # LayoutLMv3 encoding
        lm_out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox
        )
        # Keep only the text-token positions (image patch tokens follow them)
        seq_len = input_ids.size(1)
        text_feats = lm_out.last_hidden_state[:, :seq_len, :]

        # Projection → T5 decoding
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id
        )

    # Decode generated ids to text
    result = tokenizer.batch_decode(
        gen_ids, skip_special_tokens=True
    )[0]
    return result

# ── 3) GRADIO UI ───────────────────────────────────────────────────────
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.File(label="Upload JSON (NDJSON)")
    ],
    outputs="text",
    title="OCR Reorder Pipeline"
)

if __name__ == "__main__":
    demo.launch(share=True)
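
# Example NDJSON line the app expects, one JSON object per line. The field
# names are taken from infer() above; the concrete values are illustrative
# assumptions only (boxes shown in the 0-1000 scale noted earlier):
# {"img_name": "page_001.png",
#  "src_word_list": ["Total", "amount", "due:"],
#  "src_wordbox_list": [[82, 14, 161, 32], [168, 14, 240, 32], [247, 14, 291, 32]]}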