import os
import json
from PIL import Image
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import (
    AutoProcessor,
    LayoutLMv3Model,
    T5ForConditionalGeneration,
    AutoTokenizer
)

# ── 1) MODEL SETUP ─────────────────────────────────────────────────────
repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"

# Processor
processor = AutoProcessor.from_pretrained(
    repo,
    subfolder="preprocessor",
    apply_ocr=False
)
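# apply_ocr=False: the processor does not run OCR itself; word lists and
# bounding boxes are supplied explicitly inside infer() below.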

# Encoder & Decoder
layout_model = LayoutLMv3Model.from_pretrained(repo).to("cpu").eval()
t5_model     = T5ForConditionalGeneration.from_pretrained(repo).to("cpu").eval()
tokenizer    = AutoTokenizer.from_pretrained(
    repo, subfolder="preprocessor"
)

# Ensure decoder_start_token_id and bos_token_id are set
if t5_model.config.decoder_start_token_id is None:
    fallback = tokenizer.bos_token_id or tokenizer.eos_token_id
    t5_model.config.decoder_start_token_id = fallback
if t5_model.config.bos_token_id is None:
    t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
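# (Standard T5 checkpoints already set decoder_start_token_id to the pad token,
#  so the fallbacks above only kick in if the uploaded config omits the fields.)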

# Projection head
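# The full checkpoint is downloaded and only its "projection" sub-state-dict is
# loaded; the layer shapes below must match the head used at training time.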
ckpt_file   = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
ckpt        = torch.load(ckpt_file, map_location="cpu")
proj_state  = ckpt["projection"]
projection  = torch.nn.Sequential(
    torch.nn.Linear(768, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU()
).to("cpu")
projection.load_state_dict(proj_state)

# ── 2) INFERENCE FUNCTION ─────────────────────────────────────────────
def infer(image_path, json_file):
    img_name = os.path.basename(image_path)

    # Load NDJSON
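    # One standalone JSON object per line. Illustrative schema (field names
    # taken from the lookups below; the values here are made up):
    #   {"img_name": "page_001.png",
    #    "src_word_list": ["Invoice", "No.", ...],
    #    "src_wordbox_list": [[x0, y0, x1, y1], ...]}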
    # gr.File may return a temp-file object (older Gradio) or a plain filepath
    # string (newer Gradio); handle both.
    json_path = json_file.name if hasattr(json_file, "name") else json_file
    data = []
    with open(json_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            data.append(json.loads(line))

    entry = next((e for e in data if e.get("img_name") == img_name), None)
    if entry is None:
        return f"❌ No JSON entry found for image '{img_name}'"

    words = entry.get("src_word_list", [])
    boxes = entry.get("src_wordbox_list", [])

    # Preprocess image + tokens
    img = Image.open(image_path).convert("RGB")
    encoding = processor(
        [img], [words], boxes=[boxes],
        return_tensors="pt", padding=True, truncation=True
    )
    pixel_values   = encoding.pixel_values.to("cpu")
    input_ids      = encoding.input_ids.to("cpu")
    attention_mask = encoding.attention_mask.to("cpu")
    bbox           = encoding.bbox.to("cpu")

    # Forward pass
    with torch.no_grad():
        # LayoutLMv3 encoding
        lm_out     = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox
        )
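        # LayoutLMv3 appends visual patch embeddings after the text tokens, so
        # keep only the first seq_len positions as text-token features.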
        seq_len    = input_ids.size(1)
        text_feats = lm_out.last_hidden_state[:, :seq_len, :]

        # Projection + T5 decoding
        proj_feats = projection(text_feats)
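        # generate() runs T5's encoder over these embeddings directly (instead
        # of looking up token embeddings), so decoding is conditioned on the
        # projected layout-aware text features.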
        gen_ids    = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id,
            bos_token_id=t5_model.config.bos_token_id
        )

    # Decode and return
    result = tokenizer.batch_decode(
        gen_ids, skip_special_tokens=True
    )[0]
    return result

# ── 3) GRADIO INTERFACE ────────────────────────────────────────────────
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.File(label="Upload JSON (NDJSON)")
    ],
    outputs="text",
    title="OCR Reorder Pipeline"
)

if __name__ == "__main__":
    demo.launch(share=True)
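
# Local usage sketch (assumes this file is saved as app.py and the packages
# imported above are installed):
#   python app.py
# Then open the printed local or share URL, upload a page image together with
# the NDJSON file whose "img_name" entry matches the uploaded image's filename.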