stupidog04 committed
Commit ddb4ddb · 1 Parent(s): af0cbfe

Add enlarge box ratio to ui, plot for temporal profile

Files changed (1)
  1. app.py  +55 -28
app.py CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
 import cv2
 import pandas as pd
 from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+import matplotlib.pyplot as plt
 
 
 #torch.hub.download_url_to_file('https://github.com/AaronCWacker/Yggdrasil/blob/main/images/BeautyIsTruthTruthisBeauty.JPG', 'BeautyIsTruthTruthisBeauty.JPG')
@@ -22,6 +23,17 @@ torch.hub.download_url_to_file('https://github.com/JaidedAI/EasyOCR/raw/master/e
 torch.hub.download_url_to_file('https://i.imgur.com/mwQFd7G.jpeg', 'Hindi.jpeg')
 
 
+def plot_temporal_profile(temporal_profile):
+    fig = plt.figure()
+    for i, profile in enumerate(temporal_profile):
+        x, y = zip(*profile)
+        plt.plot(x, y, label=f"Box {i+1}")
+    plt.title("Temporal Profiles")
+    plt.xlabel("Time (s)")
+    plt.ylabel("Value")
+    plt.legend()
+    return fig
+
 def draw_boxes(image, bounds, color='yellow', width=2):
     draw = ImageDraw.Draw(image)
     for bound in bounds:
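
A quick way to sanity-check the new helper outside the app (the data below is hypothetical; in app.py the values come from per-box OCR readings, which may be strings and are then plotted as categories by matplotlib):

# Hypothetical smoke test for plot_temporal_profile; not part of the commit.
profiles = [
    [(0.0, 1.2), (1.0, 1.5), (2.0, 1.7)],  # Box 1: (time in s, value)
    [(0.0, 0.8), (1.0, 0.9), (2.0, 1.1)],  # Box 2
]
fig = plot_temporal_profile(profiles)
fig.savefig("temporal_profiles.png")  # or return the fig to gr.Plot, as inference() does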
@@ -63,7 +75,20 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-printed')
 model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-printed').to(device)
 
-def inference(video, lang, time_step, full_scan, number_filter, use_trocr, period_index):
+def process_box(box, frame, enlarge_ratio):
+    x1, y1 = box[0][0]
+    x2, y2 = box[0][2]
+    enlarge_ratio = enlarge_ratio/2
+    box_width = x2 - x1
+    box_height = y2 - y1
+    x1 = max(0, int(x1 - enlarge_ratio * box_width))
+    x2 = min(frame.shape[1], int(x2 + enlarge_ratio * box_width))
+    y1 = max(0, int(y1 - enlarge_ratio * box_height))
+    y2 = min(frame.shape[0], int(y2 + enlarge_ratio * box_height))
+    cropped_frame = frame[y1:y2, x1:x2]
+    return cropped_frame
+
+def inference(video, lang, full_scan, number_filter, use_trocr, time_step, period_index, box_enlarge_ratio=0.4):
     output = 'results.mp4'
     reader = easyocr.Reader(lang)
     bounds = []
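
As a reading aid, here is how the symmetric enlargement plays out on a hypothetical EasyOCR detection (corners ordered top-left, top-right, bottom-right, bottom-left). The ratio is halved, so 0.4 adds 20% of the box size on each side, clamped to the frame:

import numpy as np

# Hypothetical check; not part of the commit.
frame = np.zeros((100, 200, 3), dtype=np.uint8)  # height 100, width 200
box = ([[50, 40], [150, 40], [150, 60], [50, 60]], "42", 0.99)
crop = process_box(box, frame, enlarge_ratio=0.4)
# The 100x20 box grows by 0.2 * size per side -> x: 30..170, y: 36..64
print(crop.shape)  # (28, 140, 3)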
@@ -73,7 +98,7 @@ def inference(video, lang, time_step, full_scan, number_filter, use_trocr, period_index):
     frame_rate = vidcap.get(cv2.CAP_PROP_FPS)
     output_frames = []
     temporal_profiles = []
-    compress_mp4 = False
+    compress_mp4 = True
 
     # Get the positions of the largest boxes in the first frame
     bounds = reader.readtext(frame)
@@ -91,27 +116,24 @@ def inference(video, lang, time_step, full_scan, number_filter, use_trocr, period_index):
     # Match bboxes to position and store the text read by OCR
     while success:
         if count % (int(frame_rate * time_step)) == 0:
-            if full_scan:
-                bounds = reader.readtext(frame)
-                for box in bounds:
+            bounds = reader.readtext(frame) if full_scan else largest_boxes
+            for i, box in enumerate(bounds):
+                if full_scan:
+                    # Match box to previous box
                     bbox_pos = box_position(box)
                     for i, position in enumerate(positions):
                         distance = np.linalg.norm(np.array(bbox_pos) - np.array(position))
                         if distance < 50:
-                            temporal_profiles[i].append((count / frame_rate, box[1]))
-                            break
-            else:
-                for i, box in enumerate(largest_boxes):
-                    x1, y1 = box[0][0]
-                    x2, y2 = box[0][2]
-                    box_width = x2 - x1
-                    box_height = y2 - y1
-                    ratio = 0.2
-                    x1 = max(0, int(x1 - ratio * box_width))
-                    x2 = min(frame.shape[1], int(x2 + ratio * box_width))
-                    y1 = max(0, int(y1 - ratio * box_height))
-                    y2 = min(frame.shape[0], int(y2 + ratio * box_height))
-                    cropped_frame = frame[y1:y2, x1:x2]
+                            if use_trocr:
+                                cropped_frame = process_box(box, frame, enlarge_ratio=box_enlarge_ratio)
+                                pixel_values = processor(images=cropped_frame, return_tensors="pt").pixel_values
+                                generated_ids = model.generate(pixel_values.to(device))
+                                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+                                temporal_profiles[i].append((count / frame_rate, generated_text))
+                            else:
+                                temporal_profiles[i].append((count / frame_rate, box[1]))
+                else:
+                    cropped_frame = process_box(box, frame, enlarge_ratio=box_enlarge_ratio)
                     if use_trocr:
                         pixel_values = processor(images=cropped_frame, return_tensors="pt").pixel_values
                         generated_ids = model.generate(pixel_values.to(device))
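
The full-scan branch keeps the 50-pixel nearest-position rule from before; in isolation it behaves like this (hypothetical coordinates):

import numpy as np

# Hypothetical illustration of the distance gate; not part of the commit.
positions = [(100, 40), (300, 40)]   # box centers found in the first frame
bbox_pos = (104, 43)                 # center of a detection in a later frame
matches = [i for i, p in enumerate(positions)
           if np.linalg.norm(np.array(bbox_pos) - np.array(p)) < 50]
print(matches)  # [0] -> the reading is appended to Box 1's temporal profile

Note that the rewritten loop no longer breaks after the first match, so a detection within 50 px of several stored positions now updates each of them.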
@@ -154,10 +176,10 @@ def inference(video, lang, time_step, full_scan, number_filter, use_trocr, period_index):
     # Draw boxes with box indices in the first frame of the output video
     im = Image.fromarray(output_frames[0])
     draw = ImageDraw.Draw(im)
-    font_size = 30
+    font_size = 50
     font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
     for i, box in enumerate(largest_boxes):
-        draw.text((box_position(box)), f"Box {i+1}", fill='red', font=ImageFont.truetype(font_path, font_size))
+        draw.text((box_position(box)), f"{i+1}", fill='red', font=ImageFont.truetype(font_path, font_size))
 
     output_video.release()
     vidcap.release()
@@ -176,7 +198,10 @@ def inference(video, lang, time_step, full_scan, number_filter, use_trocr, period_index):
         df_list.append({"Box": f"Box {i+1}", "Time (s)": t, "Text": text})
         df_list.append({"Box": f"", "Time (s)": "", "Text": ""})
     df = pd.concat([pd.DataFrame(df_list)])
-    return output, im, df
+
+    # generate the plot of temporal profile
+    plot_fig = plot_temporal_profile(temporal_profiles)
+    return output, im, plot_fig, df
 
 
 title = '🖼️Video to Multilingual OCR👁️Gradio'
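
A small aside on the unchanged df construction in the hunk above: wrapping a single frame in pd.concat is a no-op, so this would be equivalent:

# Equivalent construction (a suggestion, not in the commit):
df = pd.DataFrame(df_list)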
@@ -184,7 +209,7 @@ description = 'Multilingual OCR which works conveniently on all devices in multi
 article = "<p style='text-align: center'></p>"
 
 examples = [
-    ['test.mp4',['en'],10,False,True,True,1]
+    ['test.mp4',['en'],False,True,True,10,1,0.4]
 ]
 
 css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
@@ -205,16 +230,18 @@ gr.Interface(
     [
         gr.inputs.Video(label='Input Video'),
         gr.inputs.CheckboxGroup(choices, type="value", default=['en'], label='Language'),
-        gr.inputs.Number(label='Time Step (in seconds)', default=1.0),
         gr.inputs.Checkbox(label='Full Screen Scan'),
-        gr.inputs.Checkbox(label='Use TrOCR large (this is only available when Full Screen Scan is disable)'),
+        gr.inputs.Checkbox(label='Use TrOCR large'),
         gr.inputs.Checkbox(label='Number Filter (remove non-digit char and insert period)'),
-        gr.inputs.Textbox(label="period position",default=1)
+        gr.inputs.Number(label='Time Step (in seconds)', default=1.0),
+        gr.inputs.Number(label="period position",default=1),
+        gr.inputs.Number(label='Box enlarge ratio', default=0.4)
     ],
     [
         gr.outputs.Video(label='Output Video'),
         gr.outputs.Image(label='Output Preview', type='numpy'),
-        gr.outputs.Dataframe(headers=['Box', 'Time (s)', 'Text'], type='pandas'),
+        gr.Plot(label='Temporal Profile'),
+        gr.outputs.Dataframe(headers=['Box', 'Time (s)', 'Text'], type='pandas', max_rows=15)
     ],
     title=title,
     description=description,
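
The new gr.Plot output consumes the matplotlib Figure returned by inference(). A minimal sketch of that wiring, assuming Gradio 3.x (demo_plot is an illustrative stand-in, not from app.py):

import gradio as gr
import matplotlib.pyplot as plt

def demo_plot(n):
    # Hypothetical stand-in for inference(): returns a Figure for gr.Plot.
    fig = plt.figure()
    xs = list(range(int(n)))
    plt.plot(xs, [x * x for x in xs])
    return fig

gr.Interface(demo_plot,
             gr.inputs.Number(label='Points', default=10),
             gr.Plot(label='Temporal Profile')).launch()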
@@ -222,4 +249,4 @@ gr.Interface(
     examples=examples,
     css=css,
     enable_queue=True
-).launch(debug=True)
+).launch(debug=True, share=True)