# Hugging Face Space app (Space status when captured: Running)
import os | |
import json | |
import base64 | |
from io import BytesIO | |
from PIL import Image | |
import gradio as gr | |
import torch | |
from huggingface_hub import hf_hub_download | |
from transformers import ( | |
AutoProcessor, | |
LayoutLMv3Model, | |
T5ForConditionalGeneration, | |
AutoTokenizer | |
) | |
# ── 1) MODEL SETUP ─────────────────────────────────────────────────────
# All models are loaded exactly once at import time and shared by infer().
repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"

# Everything runs on CPU (Spaces are CPU-only).
device = torch.device("cpu")

# Processor for LayoutLMv3. apply_ocr=False: words and boxes are supplied by
# the caller (from the uploaded JSON) instead of being produced by OCR.
processor = AutoProcessor.from_pretrained(
    repo,
    subfolder="preprocessor",
    apply_ocr=False,
)

# LayoutLMv3 encoder (nn.Module.to returns the module itself).
layout_model = LayoutLMv3Model.from_pretrained(repo).to(device)
layout_model.eval()

# T5 decoder & tokenizer.
t5_model = T5ForConditionalGeneration.from_pretrained(repo).to(device)
t5_model.eval()
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="preprocessor")

# Ensure decoder_start_token_id is set. T5 tokenizers normally define no BOS
# token (bos_token_id is None), and T5 conventionally starts decoding from the
# pad token, so fall back to pad_token_id when there is no BOS id.
if t5_model.config.decoder_start_token_id is None:
    t5_model.config.decoder_start_token_id = (
        tokenizer.bos_token_id
        if tokenizer.bos_token_id is not None
        else tokenizer.pad_token_id
    )

# Projection head: maps LayoutLMv3 hidden states to T5's d_model.
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code; only load from a trusted repo. If the checkpoint holds only
# tensors, weights_only=True would be safer — TODO confirm.
ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
ckpt = torch.load(ckpt_file, map_location="cpu")
proj_state = ckpt["projection"]
# Only the projection weights are used; release the rest of the checkpoint.
del ckpt
projection = torch.nn.Sequential(
    torch.nn.Linear(layout_model.config.hidden_size, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU(),
)
projection.load_state_dict(proj_state)
projection.to(device)
projection.eval()
# ── 2) INFERENCE FUNCTION ─────────────────────────────────────────────
def _resolve_upload_path(upload):
    """Return a filesystem path for a Gradio file upload.

    Gradio 3.x passes a tempfile-like object exposing ``.name``; Gradio 4's
    ``gr.File`` (type="filepath") passes the path itself. Accept both.
    """
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return upload.name


def _find_entry(json_path, img_name):
    """Return the first NDJSON record whose ``img_name`` equals *img_name*.

    Blank lines are skipped and the scan stops at the first match (later lines
    are not parsed). Returns ``None`` when no record matches.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            if record.get("img_name") == img_name:
                return record
    return None


def infer(image_path, json_file):
    """Run the LayoutLMv3 -> projection -> T5 pipeline on one uploaded image.

    Args:
        image_path: Path of the uploaded image. Its basename is used to look
            up the matching record in the NDJSON file.
        json_file: The uploaded NDJSON file (one JSON object per line), given
            either as a path or as an object with a ``.name`` path. Records
            are read via their ``img_name``, ``src_word_list`` and
            ``src_wordbox_list`` keys.

    Returns:
        The decoded T5 output text, or a message starting with "❌" when an
        input is missing, no record matches the image, or the record's word
        and box counts differ.
    """
    if not image_path or json_file is None:
        return "❌ Please upload both an image and its NDJSON file."
    img_name = os.path.basename(image_path)

    # 2.a/2.b) Locate the NDJSON record for the uploaded image.
    entry = _find_entry(_resolve_upload_path(json_file), img_name)
    if entry is None:
        return f"❌ No JSON entry found for image '{img_name}'"
    words = entry.get("src_word_list", [])
    boxes = entry.get("src_wordbox_list", [])
    # The processor expects exactly one box per word.
    if len(words) != len(boxes):
        return (
            f"❌ JSON entry for '{img_name}' has {len(words)} words "
            f"but {len(boxes)} boxes"
        )

    # 2.c) Open and preprocess the image + tokens + boxes.
    # NOTE(review): LayoutLMv3 expects boxes normalized to a 0-1000 scale;
    # confirm src_wordbox_list is already normalized.
    with Image.open(image_path) as raw_img:  # close the file handle promptly
        img = raw_img.convert("RGB")
    encoding = processor(
        [img], [words], boxes=[boxes],
        return_tensors="pt", padding=True, truncation=True,
    )
    pixel_values = encoding.pixel_values.to(device)
    input_ids = encoding.input_ids.to(device)
    attention_mask = encoding.attention_mask.to(device)
    bbox = encoding.bbox.to(device)

    # 2.d) Forward pass.
    with torch.no_grad():
        lm_out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
        )
        # Keep only the text-token positions: LayoutLMv3 places the
        # image-patch states after the input_ids positions.
        seq_len = input_ids.size(1)
        text_feats = lm_out.last_hidden_state[:, :seq_len, :]
        # Projection → T5 decoding, masked with the text attention mask.
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id,
        )

    # Decode the single generated sequence to text.
    return tokenizer.batch_decode(gen_ids, skip_special_tokens=True)[0]
# ── 3) GRADIO UI ───────────────────────────────────────────────────────
# Single-page interface: an image plus its NDJSON annotations in, text out.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.File(label="Upload JSON (NDJSON)"),
    ],
    outputs="text",
    title="OCR Reorder Pipeline",
)

if __name__ == "__main__":
    # Hugging Face Spaces set SPACE_ID and already serve the app publicly;
    # Gradio does not support share=True there (it warns and ignores it), so
    # only request a public share link when running somewhere else.
    demo.launch(share="SPACE_ID" not in os.environ)