# royaljackal's picture
# Update app.py
# bf9402d verified
from transformers import pipeline, TrOCRProcessor, VisionEncoderDecoderModel, T5ForConditionalGeneration, T5Tokenizer
from pdf2image import convert_from_path, convert_from_bytes
from IPython.display import clear_output
from PIL import Image
import cv2
import numpy as np
import torch
import gradio as gr
# Geometry thresholds for text-line detection, all measured in pixels.
MIN_BOX_WIDTH = 8  # minimum width of a text region; narrower line crops are skipped
MIN_BOX_HEIGHT = 15  # minimum height of a text region; shorter line crops are skipped
MAX_PART_WIDTH = 600  # maximum width of one slice of a line sent to OCR
BOX_HEIGHT_TOLERANCE = 8  # max vertical offset between boxes merged into the same line

# All models are loaded once, at import time, and shared by the functions below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# OCR: TrOCR checkpoint for printed text.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-printed").to(device)

# Summarization: BART fine-tuned on CNN/DailyMail.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=device)

# Translation: small T5 model (EN/RU/ZH), used for the Russian output option.
model_translation = T5ForConditionalGeneration.from_pretrained('utrobinmv/t5_translate_en_ru_zh_small_1024').to(device)
tokenizer_translation = T5Tokenizer.from_pretrained('utrobinmv/t5_translate_en_ru_zh_small_1024')
def _group_boxes_into_lines(boxes, tolerance):
    """Group ``(x, y, w, h)`` boxes into lines by their top coordinate.

    Boxes are visited in order of increasing ``y``. A box joins the current
    line when its ``y`` differs from the ``y`` of the line's most recently
    added box by at most ``tolerance`` pixels; otherwise it starts a new line.
    """
    lines = []
    current_line = []
    for box in sorted(boxes, key=lambda b: b[1]):
        if current_line and abs(box[1] - current_line[-1][1]) > tolerance:
            lines.append(current_line)
            current_line = []
        current_line.append(box)
    if current_line:
        lines.append(current_line)
    return lines


def _split_line_image(line_image, max_part_width):
    """Cut a line crop into roughly equal-width vertical slices.

    Crops wider than ``max_part_width`` are split into
    ``width // max_part_width + 1`` slices of ``width // num_parts`` columns;
    the last slice also takes the remainder. Narrower crops are returned whole.
    """
    width = line_image.shape[1]
    if width <= max_part_width:
        return [line_image]
    num_parts = width // max_part_width + 1
    part_width = width // num_parts
    parts = []
    for i in range(num_parts):
        start_x = i * part_width
        end_x = width if i == num_parts - 1 else start_x + part_width
        parts.append(line_image[:, start_x:end_x])
    return parts


def _recognize_text(part_bgr):
    """Run the module-level TrOCR model on one BGR crop and return its text."""
    part_image_pil = Image.fromarray(cv2.cvtColor(part_bgr, cv2.COLOR_BGR2RGB))
    pixel_values = processor(part_image_pil, return_tensors="pt").pixel_values.to(device)
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def get_text_from_images(images):
    """OCR a sequence of page images and return the recognized text.

    For each image:

    * binarize it with an inverted Gaussian adaptive threshold;
    * take bounding boxes of the external contours;
    * group the boxes into lines by their top coordinate.

    Each line crop that is at least ``MIN_BOX_HEIGHT`` x ``MIN_BOX_WIDTH`` is
    cut into slices about ``MAX_PART_WIDTH`` wide. Every slice is recognized
    with TrOCR, and the slice texts of a line are concatenated without a
    separator.

    Args:
        images: Sized iterable of PIL images (``len()`` is used for progress
            output).

    Returns:
        The kept lines of all images, in order, joined with ``"\\n"``.
    """
    extracted_text = []
    for image_number, image in enumerate(images, start=1):
        # Force 3-channel RGB first: grayscale ("L") or palette ("P") uploads
        # become 2-D arrays, which the RGB->BGR conversion cannot handle.
        image_cv = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
        thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        bounding_boxes = [cv2.boundingRect(contour) for contour in contours]
        lines = _group_boxes_into_lines(bounding_boxes, BOX_HEIGHT_TOLERANCE)
        for line_number, line in enumerate(lines, start=1):
            # Union of all boxes on the line.
            x_min = min(x for x, _, _, _ in line)
            y_min = min(y for _, y, _, _ in line)
            x_max = max(x + w for x, _, w, _ in line)
            y_max = max(y + h for _, y, _, h in line)
            line_image = image_cv[y_min:y_max, x_min:x_max]
            if line_image.size == 0 or line_image.shape[0] < MIN_BOX_HEIGHT or line_image.shape[1] < MIN_BOX_WIDTH:
                continue
            parts = _split_line_image(line_image, MAX_PART_WIDTH)
            part_texts = []
            for part_number, part in enumerate(parts, start=1):
                # Progress only; the old full-text dump here grew quadratically.
                print(f"Images: {image_number}/{len(images)}")
                print(f"Lines: {line_number}/{len(lines)}")
                print(f"Parts: {part_number}/{len(parts)}")
                part_texts.append(_recognize_text(part))
            extracted_text.append("".join(part_texts))
    return "\n".join(extracted_text)
def summarize(text, max_length=300, min_length=150):
    """Summarize *text* with the module-level summarization pipeline.

    Args:
        text: Source text to summarize.
        max_length: Upper bound on the generated summary length, in tokens.
        min_length: Lower bound on the generated summary length, in tokens.

    Returns:
        The ``summary_text`` field of the first pipeline result.
    """
    # truncation=True: the pipeline does not truncate by default, so input
    # longer than the model's maximum input length (e.g. several OCR'd pages)
    # would fail inside the encoder instead of being summarized.
    result = summarizer(text, max_length=max_length, min_length=min_length, do_sample=False, truncation=True)
    return result[0]['summary_text']
def translate(text, max_new_tokens=512):
    """Translate *text* into Russian with the module-level T5 model.

    Args:
        text: English source text.
        max_new_tokens: Maximum number of tokens to generate for the
            translation.

    Returns:
        The decoded translation of the first (only) generated sequence.
    """
    prefix = 'translate to ru: '
    inputs = tokenizer_translation(prefix + text, return_tensors="pt").to(device)
    # Without an explicit limit, generate() uses the generation config's
    # max_length. That defaults to 20 tokens unless the checkpoint overrides
    # it, which is far too short for a 150-300 token summary.
    generated_tokens = model_translation.generate(**inputs, max_new_tokens=max_new_tokens)
    result = tokenizer_translation.batch_decode(generated_tokens, skip_special_tokens=True)
    return result[0]
def launch(images, language):
    """Run OCR -> summarization (-> Russian translation) on a list of images.

    Args:
        images: List of PIL images; ``None`` or empty means no input.
        language: ``"rus"`` to translate the summary, anything else returns
            it as produced by the summarizer.

    Returns:
        The summary text, or a short message when there is nothing to
        summarize.
    """
    if not images:
        return "No input provided."
    raw_text = get_text_from_images(images)
    # Summarizing an empty string with min_length=150 would only produce
    # hallucinated text.
    if not raw_text.strip():
        return "No text found."
    summary = summarize(raw_text)
    if language == "rus":
        return translate(summary)
    return summary
def pdf_to_image(pdf, index=0):
    """Rasterize a single page of *pdf*.

    Args:
        pdf: PDF document contents as bytes.
        index: Zero-based page number. ``None`` or a negative value selects
            no page.

    Returns:
        A one-element list holding the page as a PIL image, or ``[]`` when
        *index* does not name an existing page.
    """
    if index is None:
        return []
    index = int(index)
    if index < 0:
        return []
    # pdf2image pages are 1-based and inclusive. Rendering only the requested
    # page avoids rasterizing the whole document just to keep one image; a
    # page past the end yields an empty list.
    images = convert_from_bytes(pdf, first_page=index + 1, last_page=index + 1)
    return images[:1]
def pdf_to_images(pdf):
    """Rasterize every page of *pdf* (raw PDF bytes) into a list of PIL images."""
    return convert_from_bytes(pdf)
def process_pdf(pdf_file, process_mode, page_index, language):
    """Gradio handler for the PDF tab.

    Args:
        pdf_file: Uploaded PDF as bytes; ``None``/empty when nothing was
            uploaded.
        process_mode: ``"all"`` for every page, ``"single"`` for one page.
        page_index: Zero-based page number used in ``"single"`` mode.
        language: Output language passed to ``launch``.

    Returns:
        The text produced by ``launch``, a message when no file was
        uploaded, or ``None`` for an unrecognized mode.
    """
    # Without this guard, convert_from_bytes(None) raises before launch()
    # can report the missing input.
    if not pdf_file:
        return "No input provided."
    if process_mode == "all":
        return launch(pdf_to_images(pdf_file), language)
    if process_mode == "single":
        return launch(pdf_to_image(pdf_file, page_index), language)
    return None
def process_images(images, language):
    """Gradio handler for the Images tab.

    Args:
        images: Uploaded image files (anything ``PIL.Image.open`` accepts);
            ``None``/empty when nothing was uploaded.
        language: Output language passed to ``launch``.

    Returns:
        The text produced by ``launch``, or a message when there is no input.
    """
    if not images:
        return "No input provided."
    pil_images = [Image.open(image) for image in images]
    # The result must be returned, otherwise the output textbox stays empty.
    return launch(pil_images, language)
class PrintToTextbox:
    """File-like sink that accumulates written text and forwards it to a textbox.

    It provides ``write``/``flush``, so it looks intended as a ``sys.stdout``
    replacement; the visible code in this file never instantiates it.
    """

    def __init__(self, textbox):
        # Presumably a Gradio component; only its ``update`` method is used.
        self.textbox = textbox
        self.buffer = ""  # everything written so far

    def write(self, text):
        """Append *text*, push the whole buffer to ``textbox.update``.

        Returns the number of characters written, as ``io.TextIOBase.write``
        does.
        """
        self.buffer += text
        # NOTE(review): in Gradio, calling ``update`` on a component outside
        # an event handler's return value may not refresh the UI -- confirm
        # this has a visible effect.
        self.textbox.update(self.buffer)
        return len(text)

    def flush(self):
        """No-op: text is kept in ``self.buffer`` and forwarded on every write."""
def update_page_index_visibility(process_mode):
    """Show the page-index input only while the "single" process mode is selected."""
    show_index = process_mode == "single"
    return gr.update(visible=show_index)
# Gradio UI: a shared output-language selector, one tab for PDFs and one for
# images. Components render in the order they are created here.
with gr.Blocks() as demo:
    gr.Markdown("# PDF and Image Text Summarizer")
    gr.Markdown("Upload a PDF file or images to extract and summarize text.")
    gr.Markdown("Takes about 10 minutes per page.")
    # "rus" sends the summary through translate(); any other value returns it
    # untranslated (see launch()).
    language = gr.Radio(choices=["rus", "eng"], label="Output Language", value="rus")
    with gr.Tabs():
        with gr.TabItem("PDF"):
            # type="binary": process_pdf receives the file contents, which
            # are passed on to convert_from_bytes.
            pdf_file = gr.File(label="Upload PDF File", type="binary")
            process_mode = gr.Radio(choices=["single", "all"], label="Process Mode", value="single")
            # Zero-based page number; hidden unless the mode is "single"
            # (see the process_mode.change handler below).
            page_index = gr.Number(label="Page Index", value=0, precision=0)
            # NOTE(review): labelled "Extracted Text", but it receives the
            # (possibly translated) summary from launch(), not the raw OCR text.
            pdf_output = gr.Textbox(label="Extracted Text")
            pdf_button = gr.Button("Extract Text from PDF")
        with gr.TabItem("Images"):
            images = gr.Files(label="Upload Images", file_types=["image"])
            # NOTE(review): also shows the summary, not the raw OCR text.
            image_output = gr.Textbox(label="Extracted Text")
            image_button = gr.Button("Extract Text from Images")
    # Event wiring; listeners are registered inside the Blocks context.
    pdf_button.click(process_pdf, inputs=[pdf_file, process_mode, page_index, language], outputs=pdf_output)
    image_button.click(process_images, inputs=[images, language], outputs=image_output)
    process_mode.change(update_page_index_visibility, inputs=process_mode, outputs=page_index)
# Start the app server (runs at import time of this script).
demo.launch(debug=True)