import re

import gradio as gr
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("nielsr/donut-demo")
model = VisionEncoderDecoderModel.from_pretrained("nielsr/donut-demo")

# Run on GPU when available; move the model once at startup rather than on every request.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
def donut(input_img):
    # prepare encoder inputs
    pixel_values = processor(input_img.convert("RGB"), return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    # prepare decoder inputs
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    decoder_input_ids = decoder_input_ids.to(device)
    # autoregressively generate sequence
    return model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
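
# Decode the generated token ids and convert Donut's tagged output
# (XML-like <s_...>...</s_...> fields) into a JSON dict via token2json.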
def parse_json(outputs):
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
    return processor.token2json(seq)
def predict(input_img):
    outputs = donut(input_img)
    return parse_json(outputs)
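
# Quick local sanity check; "sample.png" is an illustrative path, not part of the app:
#   from PIL import Image
#   print(predict(Image.open("sample.png")))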
gradio_app = gr.Interface(
    predict,
    inputs=gr.Image(label="Upload a document image", sources=["upload", "webcam"], type="pil"),
    outputs=[gr.JSON(label="Result")],
    title="Indonesian Identity Document OCR",
    description="Extracts images of KTP, SIM, Paspor, KK, and NPWP documents into structured text data",
)
if __name__ == "__main__":
    gradio_app.launch(share=True)