Commit ddb4ddb · Parent(s): af0cbfe
Add enlarge box ratio to ui, plot for temporal profile
app.py CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
 import cv2
 import pandas as pd
 from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+import matplotlib.pyplot as plt
 
 
 #torch.hub.download_url_to_file('https://github.com/AaronCWacker/Yggdrasil/blob/main/images/BeautyIsTruthTruthisBeauty.JPG', 'BeautyIsTruthTruthisBeauty.JPG')
@@ -22,6 +23,17 @@ torch.hub.download_url_to_file('https://github.com/JaidedAI/EasyOCR/raw/master/e
 torch.hub.download_url_to_file('https://i.imgur.com/mwQFd7G.jpeg', 'Hindi.jpeg')
 
 
+def plot_temporal_profile(temporal_profile):
+    fig = plt.figure()
+    for i, profile in enumerate(temporal_profile):
+        x, y = zip(*profile)
+        plt.plot(x, y, label=f"Box {i+1}")
+    plt.title("Temporal Profiles")
+    plt.xlabel("Time (s)")
+    plt.ylabel("Value")
+    plt.legend()
+    return fig
+
 def draw_boxes(image, bounds, color='yellow', width=2):
     draw = ImageDraw.Draw(image)
     for bound in bounds:
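As a quick check outside the app, the new helper can be driven with hand-made data. A minimal sketch, assuming plot_temporal_profile from the hunk above is in scope; the sample profiles are invented, and since inference stores (timestamp, ocr_text) pairs, a string-valued series gives a categorical y-axis in matplotlib:

    # Invented sample data: one (time, value) series per detected box.
    profiles = [
        [(0, "12.0"), (1, "12.4"), (2, "13.1")],  # Box 1
        [(0, "7.5"), (1, "7.5"), (2, "8.2")],     # Box 2
    ]
    fig = plot_temporal_profile(profiles)  # helper defined in the hunk above
    fig.savefig("temporal_profiles.png")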
@@ -63,7 +75,20 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-printed')
 model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-printed').to(device)
 
-def inference(video, lang, time_step, full_scan, number_filter, use_trocr, period_index):
+def process_box(box, frame, enlarge_ratio):
+    x1, y1 = box[0][0]
+    x2, y2 = box[0][2]
+    enlarge_ratio = enlarge_ratio/2
+    box_width = x2 - x1
+    box_height = y2 - y1
+    x1 = max(0, int(x1 - enlarge_ratio * box_width))
+    x2 = min(frame.shape[1], int(x2 + enlarge_ratio * box_width))
+    y1 = max(0, int(y1 - enlarge_ratio * box_height))
+    y2 = min(frame.shape[0], int(y2 + enlarge_ratio * box_height))
+    cropped_frame = frame[y1:y2, x1:x2]
+    return cropped_frame
+
+def inference(video, lang, full_scan, number_filter, use_trocr, time_step, period_index, box_enlarge_ratio=0.4):
     output = 'results.mp4'
     reader = easyocr.Reader(lang)
     bounds = []
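To make the new ratio concrete: process_box halves enlarge_ratio and pads each edge by that fraction of the box's width and height, clamped to the frame bounds. A minimal sketch with an invented frame and an EasyOCR-style box (four corner points, text, confidence), assuming process_box from the hunk above:

    import numpy as np

    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy 640x480 frame
    box = ([[50, 50], [150, 50], [150, 90], [50, 90]], "12.3", 0.98)

    crop = process_box(box, frame, enlarge_ratio=0.4)  # 0.4 -> 20% padding per side
    print(crop.shape)  # (56, 140, 3): rows 42:98, cols 30:170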
@@ -73,7 +98,7 @@ def inference(video, lang, time_step, full_scan, number_filter, use_trocr, period_index):
     frame_rate = vidcap.get(cv2.CAP_PROP_FPS)
     output_frames = []
     temporal_profiles = []
-    compress_mp4 =
+    compress_mp4 = True
 
     # Get the positions of the largest boxes in the first frame
     bounds = reader.readtext(frame)
@@ -91,27 +116,24 @@ def inference(video, lang, time_step, full_scan, number_filter, use_trocr, period_index):
     # Match bboxes to position and store the text read by OCR
     while success:
         if count % (int(frame_rate * time_step)) == 0:
-            if full_scan
-
-
+            bounds = reader.readtext(frame) if full_scan else largest_boxes
+            for i, box in enumerate(bounds):
+                if full_scan:
+                    # Match box to previous box
                     bbox_pos = box_position(box)
                     for i, position in enumerate(positions):
                         distance = np.linalg.norm(np.array(bbox_pos) - np.array(position))
                         if distance < 50:
-
-
-
-
-
-
-
-
-
-
-                    x2 = min(frame.shape[1], int(x2 + ratio * box_width))
-                    y1 = max(0, int(y1 - ratio * box_height))
-                    y2 = min(frame.shape[0], int(y2 + ratio * box_height))
-                    cropped_frame = frame[y1:y2, x1:x2]
+                            if use_trocr:
+                                cropped_frame = process_box(box, frame, enlarge_ratio=box_enlarge_ratio)
+                                pixel_values = processor(images=cropped_frame, return_tensors="pt").pixel_values
+                                generated_ids = model.generate(pixel_values.to(device))
+                                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+                                temporal_profiles[i].append((count / frame_rate, generated_text))
+                            else:
+                                temporal_profiles[i].append((count / frame_rate, box[1]))
+                else:
+                    cropped_frame = process_box(box, frame, enlarge_ratio=box_enlarge_ratio)
                     if use_trocr:
                         pixel_values = processor(images=cropped_frame, return_tensors="pt").pixel_values
                         generated_ids = model.generate(pixel_values.to(device))
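The unchanged matching step above pairs a detected box with a stored position when their centers are within 50 px (Euclidean distance). A standalone illustration with invented coordinates, assuming box_position returns a box's center point, as its use here suggests:

    import numpy as np

    bbox_pos = (120.0, 64.0)   # center of a box in the current frame
    position = (100.0, 60.0)   # stored center from the first frame

    distance = np.linalg.norm(np.array(bbox_pos) - np.array(position))
    print(distance)        # ~20.4
    print(distance < 50)   # True -> treated as the same box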
@@ -154,10 +176,10 @@ def inference(video, lang, time_step, full_scan, number_filter, use_trocr, period_index):
     # Draw boxes with box indices in the first frame of the output video
     im = Image.fromarray(output_frames[0])
     draw = ImageDraw.Draw(im)
-    font_size =
+    font_size = 50
     font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
     for i, box in enumerate(largest_boxes):
-        draw.text((box_position(box)), f"
+        draw.text((box_position(box)), f"{i+1}", fill='red', font=ImageFont.truetype(font_path, font_size))
 
     output_video.release()
     vidcap.release()
@@ -176,7 +198,10 @@ def inference(video, lang, time_step, full_scan, number_filter, use_trocr, period_index):
         df_list.append({"Box": f"Box {i+1}", "Time (s)": t, "Text": text})
         df_list.append({"Box": f"", "Time (s)": "", "Text": ""})
     df = pd.concat([pd.DataFrame(df_list)])
-    return output, im, df
+
+    # generate the plot of temporal profile
+    plot_fig = plot_temporal_profile(temporal_profiles)
+    return output, im, plot_fig, df
 
 
 title = '🖼️Video to Multilingual OCR👁️Gradio'
@@ -184,7 +209,7 @@ description = 'Multilingual OCR which works conveniently on all devices in multi
 article = "<p style='text-align: center'></p>"
 
 examples = [
-    ['test.mp4',['en'],
+    ['test.mp4',['en'],False,True,True,10,1,0.4]
 ]
 
 css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
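The updated example row supplies all eight inputs in order. Since Gradio binds example values to the function positionally, it reads against the new inference signature as below; note that the input list later in the file places the Use TrOCR checkbox before the Number Filter checkbox while the signature names number_filter before use_trocr, which only goes unnoticed here because both example values are True:

    # Positional binding of the example row to inference()'s parameters:
    example = ['test.mp4', ['en'], False, True, True, 10, 1, 0.4]
    (video, lang, full_scan, number_filter, use_trocr,
     time_step, period_index, box_enlarge_ratio) = example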
@@ -205,16 +230,18 @@ gr.Interface(
     [
         gr.inputs.Video(label='Input Video'),
         gr.inputs.CheckboxGroup(choices, type="value", default=['en'], label='Language'),
-        gr.inputs.Number(label='Time Step (in seconds)', default=1.0),
         gr.inputs.Checkbox(label='Full Screen Scan'),
-        gr.inputs.Checkbox(label='Use TrOCR large
+        gr.inputs.Checkbox(label='Use TrOCR large'),
         gr.inputs.Checkbox(label='Number Filter (remove non-digit char and insert period)'),
-        gr.inputs.
+        gr.inputs.Number(label='Time Step (in seconds)', default=1.0),
+        gr.inputs.Number(label="period position",default=1),
+        gr.inputs.Number(label='Box enlarge ratio', default=0.4)
     ],
     [
         gr.outputs.Video(label='Output Video'),
         gr.outputs.Image(label='Output Preview', type='numpy'),
-        gr.
+        gr.Plot(label='Temporal Profile'),
+        gr.outputs.Dataframe(headers=['Box', 'Time (s)', 'Text'], type='pandas', max_rows=15)
     ],
     title=title,
     description=description,
@@ -222,4 +249,4 @@ gr.Interface(
     examples=examples,
     css=css,
     enable_queue=True
-).launch(debug=True)
+).launch(debug=True, share=True)