	Create app.py
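This adds app.py, a Gradio app that segments uploaded PDF pages or images into text lines with OpenCV, reads each line with TrOCR (microsoft/trocr-large-printed), summarizes the extracted text with facebook/bart-large-cnn, and optionally translates the summary to Russian with utrobinmv/t5_translate_en_ru_zh_small_1024.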
app.py ADDED
import cv2
import gradio as gr
import numpy as np
import torch
from IPython.display import clear_output, display
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    pipeline,
)

MIN_BOX_WIDTH = 8         # Minimum width of a text region (pixels)
MIN_BOX_HEIGHT = 15       # Minimum height of a text region (pixels)
MAX_PART_WIDTH = 600      # Maximum width of one line segment (pixels)
BOX_HEIGHT_TOLERANCE = 8  # Maximum vertical offset between text regions grouped into one line (pixels)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# OCR model for printed text
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-printed")
model.to(device)

# Summarization pipeline
summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=device)

# English-to-Russian translation model
model_translation = T5ForConditionalGeneration.from_pretrained('utrobinmv/t5_translate_en_ru_zh_small_1024')
model_translation.to(device)
tokenizer_translation = T5Tokenizer.from_pretrained('utrobinmv/t5_translate_en_ru_zh_small_1024')

def get_text_from_images(images):
    extracted_text = []
    image_number = 0
    for image in images:
        image_number += 1
        image_cv = np.array(image)
        image_cv = cv2.cvtColor(image_cv, cv2.COLOR_RGB2BGR)

        # Binarize the page and find bounding boxes of candidate text regions
        gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
        thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        bounding_boxes = [cv2.boundingRect(contour) for contour in contours]

        def group_boxes_into_lines(boxes, tolerance=BOX_HEIGHT_TOLERANCE):
            # Sort boxes top-to-bottom, then group boxes whose y-coordinates
            # lie within `tolerance` pixels of the previous box into one line.
            sorted_boxes = sorted(boxes, key=lambda box: box[1])

            lines = []
            current_line = []

            for box in sorted_boxes:
                x, y, w, h = box

                if not current_line:
                    current_line.append(box)
                else:
                    last_box = current_line[-1]
                    last_y = last_box[1]

                    if abs(y - last_y) <= tolerance:
                        current_line.append(box)
                    else:
                        lines.append(current_line)
                        current_line = [box]

            if current_line:
                lines.append(current_line)

            return lines

        lines = group_boxes_into_lines(bounding_boxes)

        line_number = 0
        for line in lines:
            line_number += 1

            # Merge all boxes of one line into a single crop
            x_coords = [box[0] for box in line]
            y_coords = [box[1] for box in line]
            widths = [box[2] for box in line]
            heights = [box[3] for box in line]

            x_min = min(x_coords)
            y_min = min(y_coords)
            x_max = max(x_coords[i] + widths[i] for i in range(len(line)))
            y_max = max(y_coords[i] + heights[i] for i in range(len(line)))

            line_image = image_cv[y_min:y_max, x_min:x_max]

            # Skip crops too small to contain readable text
            if line_image.size == 0 or line_image.shape[0] < MIN_BOX_HEIGHT or line_image.shape[1] < MIN_BOX_WIDTH:
                continue

            # Split overly wide lines into segments of roughly equal width
            parts = []

            if line_image.shape[1] > MAX_PART_WIDTH:
                num_parts = (line_image.shape[1] // MAX_PART_WIDTH) + 1
                part_width = line_image.shape[1] // num_parts

                for i in range(num_parts):
                    start_x = i * part_width
                    end_x = (i + 1) * part_width if i < num_parts - 1 else line_image.shape[1]
                    part = line_image[:, start_x:end_x]
                    parts.append(part)
            else:
                parts.append(line_image)

            line_text = ""
            part_number = 0

            for part in parts:
                part_number += 1
                # Notebook-style progress output (clear_output/display are IPython helpers)
                clear_output()
                print(f"Images: {image_number}/{len(images)}")
                print(f"Lines: {line_number}/{len(lines)}")
                print(f"Parts: {part_number}/{len(parts)}")

                part_image_pil = Image.fromarray(cv2.cvtColor(part, cv2.COLOR_BGR2RGB))
                display(part_image_pil)
                print("\n".join(extracted_text))

                # Run TrOCR on the line segment
                pixel_values = processor(part_image_pil, return_tensors="pt").pixel_values
                pixel_values = pixel_values.to(device)
                generated_ids = model.generate(pixel_values)
                text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

                line_text += text

            extracted_text.append(line_text)

    final_text = "\n".join(extracted_text)
    return final_text

def summarize(text, max_length=300, min_length=150):
    result = summarizer(text, max_length=max_length, min_length=min_length, do_sample=False)
    return result[0]['summary_text']

def translate(text):
    # The translation model expects a task prefix before the source text
    prefix = 'translate to ru: '
    src_text = prefix + text

    input_ids = tokenizer_translation(src_text, return_tensors="pt")

    generated_tokens = model_translation.generate(**input_ids.to(device))

    result = tokenizer_translation.batch_decode(generated_tokens, skip_special_tokens=True)
    return result[0]

def launch(images, language):
    if not images:
        return "No input provided."
    raw_text = get_text_from_images(images)
    summary = summarize(raw_text)
    if language == "rus":
        return translate(summary)
    return summary

def pdf_to_image(pdf, index=0):
    # Convert a single PDF page (by index) to an image
    images = convert_from_bytes(pdf)
    if 0 <= index < len(images):
        return [images[index]]
    return []

def pdf_to_images(pdf):
    # Convert every PDF page to an image
    images = convert_from_bytes(pdf)
    return images

def process_pdf(pdf_file, process_mode, page_index, language):
    if process_mode == "all":
        return launch(pdf_to_images(pdf_file), language)
    elif process_mode == "single":
        return launch(pdf_to_image(pdf_file, page_index), language)

def process_images(images, language):
    pil_images = []
    for image in images:
        pil_images.append(Image.open(image))
    return launch(pil_images, language)

class PrintToTextbox:
    # Redirects print output into a Gradio textbox
    def __init__(self, textbox):
        self.textbox = textbox
        self.buffer = ""

    def write(self, text):
        self.buffer += text
        self.textbox.update(self.buffer)

    def flush(self):
        pass

def update_page_index_visibility(process_mode):
    # Show the page-index field only when processing a single page
    if process_mode == "single":
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)

with gr.Blocks() as demo:
    gr.Markdown("# PDF and Image Text Summarizer")
    gr.Markdown("Upload a PDF file or images to extract and summarize text.")

    language = gr.Radio(choices=["rus", "eng"], label="Select Language", value="rus")

    with gr.Tabs():
        with gr.TabItem("PDF"):
            pdf_file = gr.File(label="Upload PDF File", type="binary")
            process_mode = gr.Radio(choices=["all", "single"], label="Process Mode", value="all")
            page_index = gr.Number(label="Page Index", value=0, precision=0, visible=False)
            pdf_output = gr.Textbox(label="Extracted Text")
            pdf_button = gr.Button("Extract Text from PDF")

        with gr.TabItem("Images"):
            images = gr.Files(label="Upload Images", file_types=["image"])
            image_output = gr.Textbox(label="Extracted Text")
            image_button = gr.Button("Extract Text from Images")

    pdf_button.click(process_pdf, inputs=[pdf_file, process_mode, page_index, language], outputs=pdf_output)
    image_button.click(process_images, inputs=[images, language], outputs=image_output)
    process_mode.change(update_page_index_visibility, inputs=process_mode, outputs=page_index)

demo.launch(debug=True)
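
The commit ships no dependency list; from the imports, the Space needs torch, transformers, gradio, opencv-python, numpy, Pillow, IPython, and pdf2image (which in turn requires the poppler system package). The OCR-and-summarize path can also be exercised without starting the UI by calling `launch` directly — a minimal sketch, not part of the commit, where `sample_page.png` is a hypothetical local test image:

    from PIL import Image

    # Hypothetical smoke test: OCR one page image and print the English summary.
    page = Image.open("sample_page.png")  # assumed local test file
    print(launch([page], language="eng"))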
