import re

import torch
import gradio as gr
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Load the fine-tuned Donut checkpoint and move it to the GPU once, at startup.
processor = DonutProcessor.from_pretrained("atnanahidiw/donut-poc-1.1.1")
model = VisionEncoderDecoderModel.from_pretrained("atnanahidiw/donut-poc-1.1.1")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
def donut(sample):
    # prepare encoder inputs: resize and normalize the image into pixel values
    pixel_values = processor(sample.convert("RGB"), return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # prepare decoder inputs: Donut is steered by a task start prompt
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    decoder_input_ids = decoder_input_ids.to(device)

    # autoregressively generate the output token sequence
    return model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,  # greedy decoding; early_stopping only applies to beam search
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
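
# With return_dict_in_generate=True, generate() returns a ModelOutput whose
# .sequences field holds the generated token ids; parse_json below turns that
# token stream back into structured JSON.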
def parse_json(outputs):
    # decode the generated ids and strip the special tokens
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
    return processor.token2json(seq)
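
# Sketch of what token2json does (hypothetical tokens, not actual model output):
# a decoded sequence such as "<s_nik>1234567890123456</s_nik>" would be parsed
# into {"nik": "1234567890123456"}.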
def predict(input_img):
    outputs = donut(input_img)
    return parse_json(outputs)
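
# Quick local smoke test (a sketch; "sample.jpg" is a hypothetical path):
#
#   from PIL import Image
#   print(predict(Image.open("sample.jpg")))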
gradio_app = gr.Interface(
    predict,
    inputs=gr.Image(label="Upload a document image", sources=["upload", "webcam"], type="pil"),
    outputs=[gr.JSON(label="Result")],
    title="Indonesian Identity Document OCR",
    description="Extracts Indonesian identity document images into structured text data (KTP ✅, SIM ✅, Passport ✅, NPWP ✅, and KK)",
)
if __name__ == "__main__":
    gradio_app.launch(share=True)