royaljackal committed
Commit 549b3fa · verified
1 Parent(s): d621b69

Create app.py

Files changed (1)
  1. app.py +205 -0
app.py ADDED
@@ -0,0 +1,205 @@
import cv2
import numpy as np
import torch
import gradio as gr
from PIL import Image
from pdf2image import convert_from_bytes
from IPython.display import clear_output, display
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
)

MIN_BOX_WIDTH = 8  # Minimum width of a text region (in pixels)
MIN_BOX_HEIGHT = 15  # Minimum height of a text region (in pixels)
MAX_PART_WIDTH = 600  # Maximum width of one line segment (in pixels)
BOX_HEIGHT_TOLERANCE = 8  # Maximum difference in box y-positions for grouping into one line (in pixels)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# OCR model for printed text
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-printed")
model.to(device)

# Summarization pipeline
summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=device)

# English-to-Russian translation model
model_translation = T5ForConditionalGeneration.from_pretrained('utrobinmv/t5_translate_en_ru_zh_small_1024')
model_translation.to(device)
tokenizer_translation = T5Tokenizer.from_pretrained('utrobinmv/t5_translate_en_ru_zh_small_1024')
def get_text_from_images(images):
    extracted_text = []
    image_number = 0
    for image in images:
        image_number += 1
        image_cv = np.array(image)
        image_cv = cv2.cvtColor(image_cv, cv2.COLOR_RGB2BGR)

        # Binarize the page and find bounding boxes of text regions
        gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
        thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        bounding_boxes = [cv2.boundingRect(contour) for contour in contours]

        def group_boxes_into_lines(boxes, tolerance=BOX_HEIGHT_TOLERANCE):
            # Sort boxes top-to-bottom, then group boxes whose y-positions
            # differ by at most `tolerance` into the same line
            sorted_boxes = sorted(boxes, key=lambda box: box[1])

            lines = []
            current_line = []

            for box in sorted_boxes:
                x, y, w, h = box

                if not current_line:
                    current_line.append(box)
                else:
                    last_box = current_line[-1]
                    last_y = last_box[1]

                    if abs(y - last_y) <= tolerance:
                        current_line.append(box)
                    else:
                        lines.append(current_line)
                        current_line = [box]

            if current_line:
                lines.append(current_line)

            return lines

        lines = group_boxes_into_lines(bounding_boxes)

        line_number = 0
        for line in lines:
            line_number += 1

            # Merge all boxes in the line into a single bounding rectangle
            x_coords = [box[0] for box in line]
            y_coords = [box[1] for box in line]
            widths = [box[2] for box in line]
            heights = [box[3] for box in line]

            x_min = min(x_coords)
            y_min = min(y_coords)
            x_max = max(x_coords[i] + widths[i] for i in range(len(line)))
            y_max = max(y_coords[i] + heights[i] for i in range(len(line)))

            line_image = image_cv[y_min:y_max, x_min:x_max]

            # Skip regions too small to contain readable text
            if line_image.size == 0 or line_image.shape[0] < MIN_BOX_HEIGHT or line_image.shape[1] < MIN_BOX_WIDTH:
                continue

            # Split overly wide lines into parts so TrOCR sees manageable crops
            parts = []

            if line_image.shape[1] > MAX_PART_WIDTH:
                num_parts = (line_image.shape[1] // MAX_PART_WIDTH) + 1
                part_width = line_image.shape[1] // num_parts

                for i in range(num_parts):
                    start_x = i * part_width
                    end_x = (i + 1) * part_width if i < num_parts - 1 else line_image.shape[1]
                    part = line_image[:, start_x:end_x]
                    parts.append(part)
            else:
                parts.append(line_image)

            line_text = ""
            part_number = 0

            for part in parts:
                part_number += 1
                # Progress output (visible when running in a notebook)
                clear_output()
                print(f"Images: {image_number}/{len(images)}")
                print(f"Lines: {line_number}/{len(lines)}")
                print(f"Parts: {part_number}/{len(parts)}")

                part_image_pil = Image.fromarray(cv2.cvtColor(part, cv2.COLOR_BGR2RGB))
                display(part_image_pil)
                print("\n".join(extracted_text))

                # Run TrOCR on the line segment
                pixel_values = processor(part_image_pil, return_tensors="pt").pixel_values
                pixel_values = pixel_values.to(device)
                generated_ids = model.generate(pixel_values)
                text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

                line_text += text

            extracted_text.append(line_text)

    final_text = "\n".join(extracted_text)
    return final_text

def summarize(text, max_length=300, min_length=150):
    result = summarizer(text, max_length=max_length, min_length=min_length, do_sample=False)
    return result[0]['summary_text']

def translate(text):
    # The model expects a task prefix selecting the target language
    prefix = 'translate to ru: '
    src_text = prefix + text

    input_ids = tokenizer_translation(src_text, return_tensors="pt")

    generated_tokens = model_translation.generate(**input_ids.to(device))

    result = tokenizer_translation.batch_decode(generated_tokens, skip_special_tokens=True)
    return result[0]

def launch(images, language):
    if not images:
        return "No input provided."
    raw_text = get_text_from_images(images)
    summary = summarize(raw_text)
    if language == "rus":
        return translate(summary)
    return summary

def pdf_to_image(pdf, index=0):
    # Render a single PDF page as a PIL image
    images = convert_from_bytes(pdf)
    if 0 <= index < len(images):
        return [images[index]]
    return []

def pdf_to_images(pdf):
    # Render every PDF page as a PIL image
    images = convert_from_bytes(pdf)
    return images

def process_pdf(pdf_file, process_mode, page_index, language):
    if process_mode == "all":
        return launch(pdf_to_images(pdf_file), language)
    elif process_mode == "single":
        return launch(pdf_to_image(pdf_file, page_index), language)

def process_images(images, language):
    pil_images = []
    for image in images:
        pil_images.append(Image.open(image))
    return launch(pil_images, language)  # return so the result reaches the output textbox

# Helper that redirects print output into a Gradio textbox
# (defined but not currently wired into the interface)
class PrintToTextbox:
    def __init__(self, textbox):
        self.textbox = textbox
        self.buffer = ""

    def write(self, text):
        self.buffer += text
        self.textbox.update(self.buffer)

    def flush(self):
        pass

def update_page_index_visibility(process_mode):
    if process_mode == "single":
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)

with gr.Blocks() as demo:
    gr.Markdown("# PDF and Image Text Summarizer")
    gr.Markdown("Upload a PDF file or images to extract and summarize text.")

    language = gr.Radio(choices=["rus", "eng"], label="Select Language", value="rus")

    with gr.Tabs():
        with gr.TabItem("PDF"):
            pdf_file = gr.File(label="Upload PDF File", type="binary")
            process_mode = gr.Radio(choices=["all", "single"], label="Process Mode", value="all")
            page_index = gr.Number(label="Page Index", value=0, precision=0, visible=False)
            pdf_output = gr.Textbox(label="Extracted Text")
            pdf_button = gr.Button("Extract Text from PDF")

        with gr.TabItem("Images"):
            images = gr.Files(label="Upload Images", file_types=["image"])
            image_output = gr.Textbox(label="Extracted Text")
            image_button = gr.Button("Extract Text from Images")

    pdf_button.click(process_pdf, inputs=[pdf_file, process_mode, page_index, language], outputs=pdf_output)
    image_button.click(process_images, inputs=[images, language], outputs=image_output)
    process_mode.change(update_page_index_visibility, inputs=process_mode, outputs=page_index)

demo.launch(debug=True)
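
For reference, a minimal sketch of driving the pipeline above without the Gradio UI. The filename `sample.pdf` is hypothetical, and `pdf2image` additionally needs the poppler utilities installed on the system:

# Hypothetical standalone usage of the functions defined in app.py.
# "sample.pdf" is an assumed local file, not part of the repository.
with open("sample.pdf", "rb") as f:
    pdf_bytes = f.read()

# Summarize a single page (page 0) and translate the summary to Russian
summary_ru = process_pdf(pdf_bytes, process_mode="single", page_index=0, language="rus")
print(summary_ru)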