File size: 2,260 Bytes
b701d44
 
 
5b9baff
 
 
b701d44
5b9baff
fabf362
5b9baff
4617aca
fabf362
5b9baff
ab9088f
 
a0040a5
fabf362
4617aca
fabf362
 
 
4617aca
2ebc710
fabf362
2ebc710
 
 
 
 
ab9088f
4617aca
fabf362
 
 
5b9baff
fabf362
 
ab9088f
4617aca
fabf362
4617aca
 
5b9baff
ab9088f
4617aca
a0040a5
b701d44
5b9baff
4617aca
5b9baff
 
 
b701d44
 
5b9baff
 
b701d44
5b9baff
 
 
a0040a5
5b9baff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import json
import base64
from io import BytesIO
from PIL import Image
import gradio as gr

from inference import OcrReorderPipeline
from transformers import AutoProcessor, LayoutLMv3Model, AutoTokenizer

# ── 1) Load model + tokenizer + processor ─────────────────────────
# Everything in this section runs once, at import time: the three
# `from_pretrained` calls load the artifacts published under `repo`, and the
# resulting objects are shared by every request handled by `infer`.
repo      = "Uddipan107/ocr-layoutlmv3-base-t5-small"
model     = LayoutLMv3Model.from_pretrained(repo)
# Tokenizer and processor are both read from the repo's "preprocessor" subfolder.
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="preprocessor")
# apply_ocr=False: words and boxes are supplied by the caller (`infer` reads
# them from the uploaded NDJSON), so presumably the processor's own OCR step
# is switched off here — confirm against the processor documentation.
processor = AutoProcessor.from_pretrained(repo, subfolder="preprocessor", apply_ocr=False)
# NOTE(review): device=0 presumably selects the first GPU — confirm against
# OcrReorderPipeline; a host without CUDA may need a different value.
pipe      = OcrReorderPipeline(model, tokenizer, processor, device=0)

# ── 2) Inference function ──────────────────────────────────────────
def _find_entry(json_path, img_name):
    """Return the first NDJSON record in *json_path* whose ``img_name`` matches.

    The file is streamed line by line and scanning stops at the first match,
    so lines after the match are never parsed. Blank lines are skipped, and
    records that are not JSON objects or lack ``img_name`` never match.

    Returns:
        The matching record (a dict), or ``None`` when no record matches.

    Raises:
        ValueError: a non-blank line is not valid JSON; the message carries
            the 1-based line number within the file.
    """
    # "utf-8-sig" decodes plain UTF-8 unchanged but also drops a leading BOM,
    # which json.loads would otherwise reject on the first line.
    with open(json_path, "r", encoding="utf-8-sig") as f:
        for lineno, raw in enumerate(f, start=1):
            text = raw.strip()
            if not text:
                continue
            try:
                record = json.loads(text)
            except json.JSONDecodeError as err:
                # err.lineno is relative to this single line, so report the
                # position in the file instead.
                raise ValueError(f"line {lineno}: {err}") from err
            if isinstance(record, dict) and record.get("img_name") == img_name:
                return record
    return None


def infer(image_path, json_file):
    """Run the reorder pipeline on one image using metadata from an NDJSON file.

    Args:
        image_path: Filesystem path of the uploaded image. Its basename is
            the key looked up in the NDJSON records (``img_name``).
        json_file: The uploaded NDJSON file, either as a path string or as
            an object exposing the path via ``.name``.

    Returns:
        The first element of the pipeline's result for the image, or a
        human-readable message starting with "❌" when an input is missing,
        the NDJSON cannot be parsed, no record matches the image, or the
        matching record lacks the word/box lists.
    """
    # Gradio passes None for inputs the user left empty.
    if not image_path or json_file is None:
        return "❌ Please upload both an image and an NDJSON file"

    img_name = os.path.basename(image_path)

    # Depending on the Gradio version/configuration, the File component hands
    # over either an object exposing `.name` or a plain path string.
    json_path = getattr(json_file, "name", json_file)

    try:
        entry = _find_entry(json_path, img_name)
    except ValueError as err:
        # Covers malformed JSON and undecodable bytes (UnicodeDecodeError is
        # a ValueError subclass).
        return f"❌ Invalid NDJSON: {err}"
    if entry is None:
        return f"❌ No JSON entry found for image '{img_name}'"

    try:
        words = entry["src_word_list"]
        boxes = entry["src_wordbox_list"]
    except KeyError as err:
        return f"❌ JSON entry for image '{img_name}' is missing field {err}"

    # Re-encode the image as an RGB PNG and base64 it; the context manager
    # closes the underlying file handle once the pixels have been converted.
    buf = BytesIO()
    with Image.open(image_path) as img:
        img.convert("RGB").save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode()

    # Call pipeline with `inputs` keyword plus extra args
    reordered = pipe(inputs=b64, words=words, boxes=boxes)[0]
    return reordered

# ── 3) Gradio interface ─────────────────────────────────────────────
# Input widgets, listed in the positional order `infer` receives them:
# the image arrives as a file path, followed by the uploaded NDJSON file.
_image_input = gr.Image(type="filepath", label="Upload Image")
_ndjson_input = gr.File(label="Upload JSON (NDJSON)")

demo = gr.Interface(
    fn=infer,
    inputs=[_image_input, _ndjson_input],
    outputs="text",
    title="OCR Reorder Pipeline",
)

if __name__ == "__main__":
    # Pass share=True to launch() if you want a public link.
    demo.launch()