File size: 2,784 Bytes
66a005f
f0a8738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1e4d14
f0a8738
6e82059
f0a8738
 
 
 
 
 
 
 
f262292
f0a8738
 
 
 
 
 
 
 
6e82059
66a005f
6e82059
 
f0a8738
 
 
 
 
 
 
 
6e82059
f0a8738
 
 
 
 
 
c35c0c4
 
f0a8738
 
 
 
 
 
66a005f
b85773b
f0a8738
 
 
b85773b
 
 
f0a8738
 
 
6523597
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from src.text_extractor import TextExtractor
from tqdm import tqdm
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from transformers import pipeline
from mdutils.mdutils import MdUtils
from pathlib import Path

import gradio as gr
import fitz
import torch
import copy
import os

# Stem of the PDF currently being processed. Set by `inference` and read by
# `convert2markdown` via module-global state (single-request assumption).
FILENAME = ""

# Project-local helper that extracts tagged text spans from PDF pages.
preprocess = TextExtractor()
# Distilled Pegasus summarization checkpoint (CNN/DailyMail, 16 enc / 4 dec layers).
model_name = "sshleifer/distill-pegasus-cnn-16-4"
# Prefer GPU when available; falls back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE: model and tokenizer are loaded eagerly at import time, so importing
# this module downloads/loads the checkpoint before the UI starts.
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)

def summarize(slides):
    """Summarize every paragraph entry of *slides* with Pegasus.

    Parameters
    ----------
    slides : dict
        Maps a page identifier to a list of ``(tag, text)`` pairs, where a
        tag starting with ``'p'`` marks paragraph text and other tags
        (e.g. ``'h1'``) mark headings.

    Returns
    -------
    dict
        A deep copy of *slides* in which each paragraph's text has been
        replaced by its generated summary; headings are left untouched.
        On a per-paragraph failure the original text is kept (best-effort).
    """
    generated_slides = copy.deepcopy(slides)
    for page, contents in tqdm(generated_slides.items()):
        for idx, (tag, content) in enumerate(contents):
            if not tag.startswith('p'):
                continue  # only paragraphs are summarized
            try:
                # Renamed from `input`, which shadowed the builtin.
                encoded = tokenizer(content, truncation=True,
                                    padding="longest",
                                    return_tensors="pt").to(device)
                tensor = model.generate(**encoded)
                summary = tokenizer.batch_decode(
                    tensor, skip_special_tokens=True)[0]
                contents[idx] = (tag, summary)
            except Exception as e:
                # Deliberate best-effort: log and keep the original
                # paragraph rather than abort the whole document.
                print(e)
                print("Summarization Fails")
    return generated_slides

def convert2markdown(generated_slides):
    """Render *generated_slides* to ``<FILENAME>.md`` and return that path.

    Parameters
    ----------
    generated_slides : dict
        Maps a page identifier to a list of ``(tag, content)`` pairs, as
        produced by `summarize`.

    Returns
    -------
    str
        The Markdown file name, ``f"{FILENAME}.md"``.
    """
    mdFile = MdUtils(file_name=FILENAME, title=f'{FILENAME} Presentation')
    # Page keys are not rendered, so iterate the values only.
    for sections in generated_slides.values():
        mdFile.new_line('---\n')  # horizontal rule: one per slide/page
        for tag, content in sections:
        # A tag is either a heading ('hN') or a paragraph ('p'),
        # never both — hence elif.
            if tag.startswith('h'):
                mdFile.new_header(level=int(tag[1]), title=content)
            elif tag == 'p':
                # Paragraph text uses '<n>' as an internal line separator.
                for line in content.split('<n>'):
                    mdFile.new_line(f"{line}\n")
    mdFile.create_md_file()
    return f"{FILENAME}.md"

def inference(document):
    """Summarize the uploaded PDF and return the generated Markdown path.

    Parameters
    ----------
    document
        Gradio file object for the uploaded PDF; its ``.name`` attribute
        is the temp-file path on disk.

    Returns
    -------
    str
        Path of the Markdown file written by `convert2markdown`.
    """
    global FILENAME
    # Path.stem keeps the full base name (e.g. 'a.b.pdf' -> 'a.b') and is
    # separator-agnostic, unlike the previous manual '/'-split parsing.
    FILENAME = Path(document.name).stem
    print(f"FILENAME: {FILENAME}")
    doc = fitz.open(document)
    try:
        # Derive per-font-size heading/paragraph tags, then group into slides.
        font_counts, styles = preprocess.get_font_info(doc, granularity=False)
        size_tag = preprocess.get_font_tags(font_counts, styles)
        texts = preprocess.assign_tags(doc, size_tag)
        slides = preprocess.get_slides(texts)
    finally:
        doc.close()  # release the PDF handle even if extraction fails
    generated_slides = summarize(slides)
    markdown_path = convert2markdown(generated_slides)
    print(f"Markdown Path: {markdown_path}")
    return markdown_path


# Minimal Gradio UI: upload a PDF, receive the summarized Markdown file.
with gr.Blocks() as demo:
    pdf_input = gr.File(file_types=['pdf'])
    md_output = gr.File(label="Markdown File")
    # Alternative output widget, kept for reference:
    # md_output = gr.Textbox(label="Markdown Content")
    run_button = gr.Button("Summarized PDF")
    run_button.click(
        fn=inference,
        inputs=pdf_input,
        outputs=md_output,
        show_progress=True,
        api_name="summarize",
    )

demo.launch()