ivelin commited on
Commit
246389f
·
1 Parent(s): 1b85d75

switch to click model

Browse files
Files changed (1) hide show
  1. app.py +29 -40
app.py CHANGED
@@ -6,10 +6,9 @@ import torch
6
  import html
7
  from transformers import DonutProcessor, VisionEncoderDecoderModel
8
 
9
- pretrained_repo_name = 'ivelin/donut-refexp-combined-v1'
10
  pretrained_revision = 'main'
11
- # revision: '348ddad8e958d370b7e341acd6050330faa0500f' # Iou = 0.47
12
- # revision: '41210d7c42a22e77711711ec45508a6b63ec380f' # : IoU=0.42
13
  # use 'main' for latest revision
14
  print(f"Loading model checkpoint: {pretrained_repo_name}")
15
 
@@ -31,7 +30,7 @@ def process_refexp(image: Image, prompt: str):
31
  pixel_values = processor(image, return_tensors="pt").pixel_values
32
 
33
  # prepare decoder inputs
34
- task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_bounding_box>"
35
  prompt = task_prompt.replace("{user_input}", prompt)
36
  decoder_input_ids = processor.tokenizer(
37
  prompt, add_special_tokens=False, return_tensors="pt").input_ids
@@ -61,37 +60,28 @@ def process_refexp(image: Image, prompt: str):
61
  fr"predicted decoder sequence before token2json: {html.escape(sequence)}")
62
  seqjson = processor.token2json(sequence)
63
 
64
- # safeguard in case predicted sequence does not include a target_bounding_box token
65
- bbox = seqjson.get('target_bounding_box')
66
- if bbox is None:
67
  print(
68
- f"token2bbox seq has no predicted target_bounding_box, seq:{seq}")
69
- bbox = {"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0}
70
- return bbox
71
 
72
- print(f"predicted bounding box with text coordinates: {bbox}")
73
- # safeguard in case text prediction is missing some bounding box coordinates
74
  # or coordinates are not valid numeric values
75
  try:
76
- xmin = float(bbox.get("xmin", 0))
77
  except ValueError:
78
- xmin = 0
79
  try:
80
- ymin = float(bbox.get("ymin", 0))
81
  except ValueError:
82
- ymin = 0
83
- try:
84
- xmax = float(bbox.get("xmax", 1))
85
- except ValueError:
86
- xmax = 1
87
- try:
88
- ymax = float(bbox.get("ymax", 1))
89
- except ValueError:
90
- ymax = 1
91
  # replace str with float coords
92
- bbox = {"xmin": xmin, "ymin": ymin, "xmax": xmax,
93
- "ymax": ymax, "decoder output sequence": sequence}
94
- print(f"predicted bounding box with float coordinates: {bbox}")
95
 
96
  print(f"image object: {image}")
97
  print(f"image size: {image.size}")
@@ -99,26 +89,25 @@ def process_refexp(image: Image, prompt: str):
99
  print(f"image width, height: {width, height}")
100
  print(f"processed prompt: {prompt}")
101
 
102
- # safeguard in case text prediction is missing some bounding box coordinates
103
- xmin = math.floor(width*bbox["xmin"])
104
- ymin = math.floor(height*bbox["ymin"])
105
- xmax = math.floor(width*bbox["xmax"])
106
- ymax = math.floor(height*bbox["ymax"])
107
 
108
  print(
109
- f"to image pixel values: xmin, ymin, xmax, ymax: {xmin, ymin, xmax, ymax}")
110
-
111
- shape = [(xmin, ymin), (xmax, ymax)]
112
 
113
- # deaw bbox rectangle
114
  img1 = ImageDraw.Draw(image)
115
- img1.rectangle(shape, outline="green", width=5)
116
- img1.rectangle(shape, outline="white", width=2)
117
 
118
- return image, bbox
 
 
 
 
 
119
 
120
 
121
- title = "Demo: Donut 🍩 for UI RefExp (by GuardianUI)"
122
  description = "Gradio Demo for Donut RefExp task, an instance of `VisionEncoderDecoderModel` fine-tuned on [UIBert RefExp](https://huggingface.co/datasets/ivelin/ui_refexp_saved) Dataset (UI Referring Expression). To use it, simply upload your image and type a prompt and click 'submit', or click one of the examples to load them. See the model training <a href='https://colab.research.google.com/github/ivelin/donut_ui_refexp/blob/main/Fine_tune_Donut_on_UI_RefExp.ipynb' target='_parent'>Colab Notebook</a> for this space. Read more at the links below."
123
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
124
  examples = [["example_1.jpg", "select the setting icon from top right corner"],
 
6
  import html
7
  from transformers import DonutProcessor, VisionEncoderDecoderModel
8
 
9
+ pretrained_repo_name = 'ivelin/donut-refexp-click'
10
  pretrained_revision = 'main'
11
+ # revision can be git commit hash, branch or tag
 
12
  # use 'main' for latest revision
13
  print(f"Loading model checkpoint: {pretrained_repo_name}")
14
 
 
30
  pixel_values = processor(image, return_tensors="pt").pixel_values
31
 
32
  # prepare decoder inputs
33
+ task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_center>"
34
  prompt = task_prompt.replace("{user_input}", prompt)
35
  decoder_input_ids = processor.tokenizer(
36
  prompt, add_special_tokens=False, return_tensors="pt").input_ids
 
60
  fr"predicted decoder sequence before token2json: {html.escape(sequence)}")
61
  seqjson = processor.token2json(sequence)
62
 
63
+ # safeguard in case predicted sequence does not include a target_center token
64
+ center_point = seqjson.get('target_center')
65
+ if center_point is None:
66
  print(
67
+ f"predicted sequence has no target_center, seq:{sequence}")
68
+ center_point = {"x": 0, "y": 0}
69
+ return center_point
70
 
71
+ print(f"predicted center_point with text coordinates: {center_point}")
72
+ # safeguard in case text prediction is missing some center point coordinates
73
  # or coordinates are not valid numeric values
74
  try:
75
+ x = float(center_point.get("x", 0))
76
  except ValueError:
77
+ x = 0
78
  try:
79
+ y = float(center_point.get("y", 0))
80
  except ValueError:
81
+ y = 0
 
 
 
 
 
 
 
 
82
  # replace str with float coords
83
+ center_point = {"x": x, "y": y, "decoder output sequence": sequence}
84
+ print(f"predicted center_point with float coordinates: {center_point}")
 
85
 
86
  print(f"image object: {image}")
87
  print(f"image size: {image.size}")
 
89
  print(f"image width, height: {width, height}")
90
  print(f"processed prompt: {prompt}")
91
 
92
+ # safeguard in case text prediction is missing some center point coordinates
93
+ x = math.floor(width*center_point["x"])
94
+ y = math.floor(height*center_point["y"])
 
 
95
 
96
  print(
97
+ f"to image pixel values: x, y: {x, y}")
 
 
98
 
99
+ # draw center point circle
100
  img1 = ImageDraw.Draw(image)
 
 
101
 
102
+ r = 1
103
+ shape = [(x-r, y-r), (x+r, y+r)]
104
+ img1.ellipse(shape, outline="green", width=10)
105
+ img1.ellipse(shape, outline="white", width=5)
106
+
107
+ return image, center_point
108
 
109
 
110
+ title = "Demo: Donut 🍩 for UI RefExp - Center Point (by GuardianUI)"
111
  description = "Gradio Demo for Donut RefExp task, an instance of `VisionEncoderDecoderModel` fine-tuned on [UIBert RefExp](https://huggingface.co/datasets/ivelin/ui_refexp_saved) Dataset (UI Referring Expression). To use it, simply upload your image and type a prompt and click 'submit', or click one of the examples to load them. See the model training <a href='https://colab.research.google.com/github/ivelin/donut_ui_refexp/blob/main/Fine_tune_Donut_on_UI_RefExp.ipynb' target='_parent'>Colab Notebook</a> for this space. Read more at the links below."
112
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
113
  examples = [["example_1.jpg", "select the setting icon from top right corner"],