uartimcs commited on
Commit
b719e63
·
verified ·
1 Parent(s): 8d92410

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -26
app.py CHANGED
@@ -1,26 +1,23 @@
1
- import gradio as gr
2
- import argparse
3
- import torch
4
- from PIL import Image
5
- from donut import DonutModel
6
- def demo_process(input_img):
7
- global model, task_prompt, task_name
8
- input_img = Image.fromarray(input_img)
9
- output = model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
10
- return output
11
- parser = argparse.ArgumentParser()
12
- parser.add_argument("--task", type=str, default="Booking")
13
- parser.add_argument("--pretrained_path", type=str, default="result/train_booking/20241112_150925")
14
- args, left_argv = parser.parse_known_args()
15
- task_name = args.task
16
- task_prompt = f"<s_{task_name}>"
17
- model = DonutModel.from_pretrained("./result/train_booking/20241112_150925")
18
- if torch.cuda.is_available():
19
- model.half()
20
- device = torch.device("cuda")
21
- model.to(device)
22
- else:
23
- model.encoder.to(torch.bfloat16)
24
- model.eval()
25
- demo = gr.Interface(fn=demo_process,inputs="image",outputs="json", title=f"Donut 🍩 demonstration for `{task_name}` task",)
26
- demo.launch(debug=True)
 
1
+ import gradio as gr
2
+ import argparse
3
+ import torch
4
+ from PIL import Image
5
+ from donut import DonutModel
6
+ def demo_process(input_img):
7
+ global model, task_prompt, task_name
8
+ input_img = Image.fromarray(input_img)
9
+ output = model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
10
+ return output
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument("--task", type=str, default="SGSInvoice")
13
+ parser.add_argument("--pretrained_path", type=str, default="uartimcs/donut-invoice-extract")
14
+ args, left_argv = parser.parse_known_args()
15
+ task_name = args.task
16
+ task_prompt = f"<s_{task_name}>"
17
+
18
+
19
+
20
+ model = DonutModel.from_pretrained("uartimcs/donut-invoice-extract")
21
+ model.eval()
22
+ demo = gr.Interface(fn=demo_process,inputs="image",outputs="json", title=f"Donut 🍩 demonstration for `{task_name}` task",)
23
+ demo.launch()