doc-to-slides / app.py
com3dian's picture
Update app.py
45e4e41 verified
raw
history blame
6.88 kB
import io
import os
import pickle
import re
import tempfile

import markdown
import numpy as np
import pandas as pd
import pdfkit
import streamlit as st
import torch
from grobidmonkey import reader
from transformers import BartTokenizer, BartModel, BartForConditionalGeneration
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import pipeline

from document import Document
from BartSE import BARTAutoEncoder
def save_uploaded_file(uploaded_file):
    """Write an uploaded file into ``./uploads`` and return where it was saved.

    Args:
        uploaded_file: Object exposing ``name`` (the client-supplied file
            name) and ``getbuffer()`` (the file's bytes).

    Returns:
        The path of the written file, as a string.
    """
    upload_dir = "./uploads"
    # The name is supplied by the client, so keep only its final component;
    # otherwise a name such as "../app.py" would be written outside the
    # uploads folder. Backslashes are normalised first so Windows-style
    # paths are reduced to their last component as well.
    safe_name = os.path.basename(uploaded_file.name.replace("\\", "/"))
    file_path = os.path.join(upload_dir, safe_name)
    os.makedirs(upload_dir, exist_ok=True)  # Create 'uploads' directory if it doesn't exist
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return file_path  # Return the file path as a string
# Page header: app title followed by the upload prompt.
st.title("Paper2Slides")
st.subheader("Upload paper in pdf format")
# col1, col2 = st.columns([3, 1])
# with col1:
# uploaded_file = st.file_uploader("Choose a file")
# with col2:
# option = st.selectbox(
# 'Select parsing method.',
# ('monkey', 'x2d', 'lxml'))
# if uploaded_file is not None:
# st.write(uploaded_file.name)
# bytes_data = uploaded_file.getvalue()
# st.write(len(bytes_data), "bytes")
# saved_file_path = save_uploaded_file(uploaded_file)
# monkeyReader = reader.MonkeyReader(option)
# outline = monkeyReader.readOutline(saved_file_path)
# for pre, fill, node in outline:
# st.write("%s%s" % (pre, node.name))
# # read paper content
# essay = monkeyReader.readEssay(saved_file_path)
# with st.status("Understanding paper..."):
# Barttokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
# summ_model_path = 'com3dian/Bart-large-paper2slides-summarizer'
# summarizor = BartForConditionalGeneration.from_pretrained(summ_model_path)
# exp_model_path = 'com3dian/Bart-large-paper2slides-expander'
# expandor = BartForConditionalGeneration.from_pretrained(exp_model_path)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BartSE = BARTAutoEncoder(summarizor, summarizor, device)
# del summarizor, expandor
# document = Document(essay, Barttokenizer)
# del Barttokenizer
# length = document.merge(25, 30, BartSE, device)
# with st.status("Generating slides..."):
# summarizor = pipeline("summarization", model=summ_model_path, device = device)
# summ_text = summarizor(document.segmentation['text'], max_length=100, min_length=10, do_sample=False)
# summ_text = [text['summary_text'] for text in summ_text]
# for summ in summ_text:
# st.write(summ)
# Load the slide texts from slides_text.pkl in the working directory.
# NOTE: unpickling can execute arbitrary code, so this file must be trusted.
with open('slides_text.pkl', mode='rb') as pickle_file:
    summ_text = pickle.load(pickle_file)
# Function to turn each text into a Markdown bullet list (one bullet per sentence)

# A period ends a sentence only when whitespace or the end of the text follows
# it. A bare split on "." would also cut decimals ("95.5") and dotted names
# ("example.com") in half.
_SENTENCE_END = re.compile(r"\.(?=\s|$)")


def format(text_list):
    """Convert each text into a Markdown bullet list, one bullet per sentence.

    Note: the name shadows the ``format`` builtin; it is kept unchanged so
    existing callers keep working.

    Args:
        text_list: Iterable of strings.

    Returns:
        A list with one string per input text. Each non-empty sentence
        becomes a line of the form ``"- <sentence>.\\n"``; a text with no
        sentences yields an empty string.
    """
    format_list = []
    for text in text_list:
        # Split text at sentence-ending periods
        sentences = _SENTENCE_END.split(text)
        # Create Markdown list items, skipping empty fragments
        list_items = "".join(
            f"- {sentence.strip()}.\n" for sentence in sentences if sentence.strip()
        )
        format_list.append(list_items)
    return format_list
# Seed the session state on the first run; values already stored are left alone.
if 'page_index' not in st.session_state:
    st.session_state['page_index'] = 0
if 'summ_text' not in st.session_state:
    st.session_state['summ_text'] = format(summ_text)
if 'current_text' not in st.session_state:
    # Start by showing the text of the page the index currently points at.
    st.session_state['current_text'] = (
        st.session_state['summ_text'][st.session_state['page_index']]
    )
# Function to handle page turn
def turn_page(direction):
    """Move the current page one step and show that page's text.

    Args:
        direction: ``"next"`` to advance or ``"prev"`` to go back. Any other
            value leaves the page index unchanged.
    """
    state = st.session_state
    # Bound the index by the list that is indexed below (the session's own
    # copy) rather than the module-level ``summ_text``, so the bound check
    # and the lookup can never disagree about the number of pages.
    last_index = len(state.summ_text) - 1
    if direction == "next" and state.page_index < last_index:
        state.page_index += 1
    elif direction == "prev" and state.page_index > 0:
        state.page_index -= 1
    state.current_text = state.summ_text[state.page_index]
# Callback for the text area: keep the stored page text in sync with the edit
def update_text():
    """Store the text area's value as the current page's text and as the displayed text."""
    edited = st.session_state.text_area_value
    st.session_state.summ_text[st.session_state.page_index] = edited
    st.session_state.current_text = edited
# Editable text box for the current page; update_text runs when its value changes
text = st.text_area(
    label="Edit Text",
    value=st.session_state.current_text,
    height=200,
    key="text_area_value",
    on_change=update_text,
)

# Page-turner row: Previous button, centred page counter, Next button
col1, col2, col3 = st.columns([2.25, 12, 1.7])

# Left column: go back one page
with col1:
    st.button(label="Previous", on_click=turn_page, args=("prev",))

# Middle column: "Page X of Y", centred with a flex container
with col2:
    page_counter = (
        f'<div style="display: flex; justify-content: center; align-items: center; height: 100%;">'
        f'Page {st.session_state.page_index + 1} of {len(summ_text)}'
        f'</div>'
    )
    st.markdown(body=page_counter, unsafe_allow_html=True)

# Right column: advance one page
with col3:
    st.button(label="Next", on_click=turn_page, args=("next",))

# Render the current page's text as Markdown
st.markdown(body=st.session_state.current_text)
def render_markdown_to_html(markdown_str):
    """Return the result of ``markdown.markdown`` applied to *markdown_str*."""
    html_fragment = markdown.markdown(markdown_str)
    return html_fragment
def create_pdf_from_markdown_strings(markdown_strings):
    """Build one printable HTML document from a sequence of Markdown pages.

    Each Markdown string is rendered with ``render_markdown_to_html`` and
    wrapped in its own ``<div class="page">``; consecutive pages are
    separated by a CSS page break.

    Args:
        markdown_strings: Iterable of Markdown strings, one per page.

    Returns:
        The complete HTML document as a single string, or an empty string
        when there are no pages (so a truthiness check on the result still
        means "there is something to convert").
    """
    html_pages = [render_markdown_to_html(md) for md in markdown_strings]
    if not html_pages:
        return ""

    # Document head with a style section for font size and margins
    document_head = '''
    <html>
    <head>
    <style>
    .page {
    font-size: 16pt; /* Adjust the font size as needed */
    margin: 20mm; /* Set margins for top, right, bottom, and left */
    }
    </style>
    </head>
    <body>
    '''
    # Joining with the break element puts a page break between pages but
    # not after the last one.
    page_break = '<div style="page-break-after: always;"></div>'
    document_body = page_break.join(
        f'<div class="page">{page}</div>' for page in html_pages
    )

    # Return the assembled document. Returning the list of per-page fragments
    # would discard the styling and page breaks built above and hand a
    # non-string to the PDF converter.
    return document_head + document_body + '</body></html>'
def html_to_pdf(html_content, options=None):
    """Render HTML to a temporary PDF file and return the file's path.

    The caller owns the returned file and is responsible for deleting it.

    Args:
        html_content: The HTML to convert, as a string. An iterable of HTML
            fragments is also accepted and is concatenated in order.
        options: Optional dict of options passed through to
            ``pdfkit.from_string``. Defaults to A4 landscape, which suits
            slide-style pages.

    Returns:
        Path of the generated ``.pdf`` file.

    Raises:
        Whatever ``pdfkit.from_string`` raises; the temporary file is removed
        before the exception propagates.
    """
    if not isinstance(html_content, str):
        html_content = "".join(html_content)
    if options is None:
        options = {
            'page-size': 'A4',
            'orientation': 'Landscape'
        }

    # Reserve a unique file name, then close the handle before converting:
    # the converter writes to the path itself, and a still-open handle can
    # block that write on some platforms.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
        pdf_path = tmp_file.name

    try:
        pdfkit.from_string(html_content, pdf_path, options=options)
    except Exception:
        # delete=False means nothing else cleans this file up on failure.
        os.remove(pdf_path)
        raise
    return pdf_path
# Build the export content from the current (possibly edited) page texts.
html_content = create_pdf_from_markdown_strings(st.session_state.summ_text)

# st.button is evaluated first, so the button is always rendered; the PDF is
# only generated when it is clicked and there is content to convert.
if st.button("Download PDF") and html_content:
    pdf_path = html_to_pdf(html_content)
    try:
        with open(pdf_path, "rb") as pdf_file:
            pdf_bytes = pdf_file.read()
    finally:
        # Remove the temporary file even if reading it fails.
        os.remove(pdf_path)
    st.download_button(
        label="Download PDF",
        data=pdf_bytes,
        file_name="converted.pdf",
        mime="application/pdf"
    )