uartimcs committed
Commit 9894cf3 · verified · 1 Parent(s): 578c592

Update app.py

Files changed (1): app.py +25 -25
app.py CHANGED
@@ -1,26 +1,26 @@
- import gradio as gr
- import argparse
- import torch
- from PIL import Image
- from donut import DonutModel
- def demo_process(input_img):
-     global model, task_prompt, task_name
-     input_img = Image.fromarray(input_img)
-     output = model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
-     return output
- parser = argparse.ArgumentParser()
- parser.add_argument("--task", type=str, default="Booking")
- parser.add_argument("--pretrained_path", type=str, default="result/train_booking/20241112_150925")
- args, left_argv = parser.parse_known_args()
- task_name = args.task
- task_prompt = f"<s_{task_name}>"
- model = DonutModel.from_pretrained("./result/train_booking/20241112_150925")
- if torch.cuda.is_available():
-     model.half()
-     device = torch.device("cuda")
-     model.to(device)
- else:
-     model.encoder.to(torch.bfloat16)
- model.eval()
- demo = gr.Interface(fn=demo_process, inputs="image", outputs="json", title=f"Donut 🍩 demonstration for `{task_name}` task",)
+ import gradio as gr
+ import argparse
+ import torch
+ from PIL import Image
+ from donut import DonutModel
+ def demo_process(input_img):
+     global model, task_prompt, task_name
+     input_img = Image.fromarray(input_img)
+     output = model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
+     return output
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--task", type=str, default="Booking")
+ parser.add_argument("--pretrained_path", type=str, default="uartimcs/donut-booking-extract")
+ args, left_argv = parser.parse_known_args()
+ task_name = args.task
+ task_prompt = f"<s_{task_name}>"
+ model = DonutModel.from_pretrained("uartimcs/donut-booking-extract")
+ if torch.cuda.is_available():
+     model.half()
+     device = torch.device("cuda")
+     model.to(device)
+ else:
+     model.encoder.to(torch.bfloat16)
+ model.eval()
+ demo = gr.Interface(fn=demo_process, inputs="image", outputs="json", title=f"Donut 🍩 demonstration for `{task_name}` task",)
  demo.launch(debug=True)
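
The substantive change is that app.py now loads the fine-tuned model from the Hugging Face Hub (`uartimcs/donut-booking-extract`) instead of a local training checkpoint directory. For reference, a minimal standalone sketch of the same inference path, outside Gradio; the input filename is a placeholder, not a file from this repo:

```python
import torch
from PIL import Image
from donut import DonutModel

# Load the fine-tuned booking-extraction Donut model from the Hub,
# as the updated app.py does (previously a local checkpoint path).
model = DonutModel.from_pretrained("uartimcs/donut-booking-extract")

# Mirror the precision handling in app.py: fp16 on GPU, bfloat16 encoder on CPU.
if torch.cuda.is_available():
    model.half()
    model.to(torch.device("cuda"))
else:
    model.encoder.to(torch.bfloat16)
model.eval()

# "booking.png" is a placeholder input image.
image = Image.open("booking.png").convert("RGB")
output = model.inference(image=image, prompt="<s_Booking>")["predictions"][0]
print(output)
```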