# Summarization / app.py
# ikraamkb's picture
# Update app.py
# 5e30a65 verified
# raw
# history blame
# 3.85 kB
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import RedirectResponse
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import tempfile
import os
from PIL import Image
import fitz # PyMuPDF
import docx
import openpyxl
from pptx import Presentation
import easyocr
app = FastAPI()

# Model identifiers used by the pipelines below.
_SUMMARIZER_MODEL = "FeruzaBoynazarovaas/my_awesome_billsum_model"
_FALLBACK_SUMMARIZER_MODEL = "t5-small"
_CAPTION_MODEL = "nlpconnect/vit-gpt2-image-captioning"

# Build the summarization pipeline from an explicitly loaded tokenizer and
# model; if anything in that sequence fails, report it and fall back to a
# stock model so the module still finishes importing.
try:
    tokenizer = AutoTokenizer.from_pretrained(_SUMMARIZER_MODEL, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(_SUMMARIZER_MODEL)
    summarizer = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
    )
except Exception as exc:
    print(f"Error loading summarizer: {exc}")
    summarizer = pipeline("text2text-generation", model=_FALLBACK_SUMMARIZER_MODEL)

# Image captioning pipeline and English-only OCR reader (no fallback here:
# a failure while loading either one propagates at import time).
captioner = pipeline("image-to-text", model=_CAPTION_MODEL)
reader = easyocr.Reader(['en'])
def extract_text_from_file(file_path: str, file_type: str):
    """Extract plain text from a PDF, DOCX, PPTX or XLSX document.

    Args:
        file_path: Path of the document on disk.
        file_type: File extension selecting the parser ("pdf", "docx",
            "pptx" or "xlsx"). Matching is case-insensitive and a leading
            dot is ignored, so ".PDF" is treated like "pdf".

    Returns:
        The extracted text with items joined by newlines. Failures are
        reported through the return value rather than raised:
        "Unsupported file format" for an unknown ``file_type`` and
        "Error reading file: <message>" when the parser raises.
    """
    try:
        kind = (file_type or "").lower().lstrip(".")
        if kind == "pdf":
            with fitz.open(file_path) as doc:
                return "\n".join(page.get_text() for page in doc)
        if kind == "docx":
            document = docx.Document(file_path)
            return "\n".join(p.text for p in document.paragraphs)
        if kind == "pptx":
            prs = Presentation(file_path)
            return "\n".join(
                shape.text
                for slide in prs.slides
                for shape in slide.shapes
                if hasattr(shape, "text")
            )
        if kind == "xlsx":
            wb = openpyxl.load_workbook(file_path)
            # Skip empty cells: str(None) would otherwise put the literal
            # word "None" into the extracted text for every blank cell.
            return "\n".join(
                str(cell.value)
                for sheet in wb
                for row in sheet
                for cell in row
                if cell.value is not None
            )
        return "Unsupported file format"
    except Exception as e:
        return f"Error reading file: {str(e)}"
def process_document(file):
    """Summarize an uploaded document and return the summary text.

    Args:
        file: Either a path string, or an object exposing ``.name`` (used
            for the extension) and ``.read()`` (used for the content).

    Returns:
        The generated summary. Problems are reported as text instead of
        being raised: the extractor's own message ("Unsupported file
        format" / "Error reading file: ...") is returned unchanged, and
        anything else is returned as "Processing error: <message>".
    """
    if file is None:
        return "Processing error: no file provided"

    tmp_path = None
    try:
        source_name = file if isinstance(file, str) else getattr(file, "name", "")
        file_ext = os.path.splitext(source_name)[1][1:].lower()

        if hasattr(file, "read"):
            # Stream-like upload: spool it to a temp file carrying the right
            # extension so the parsers can open it by path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_ext}") as tmp:
                tmp.write(file.read())
                tmp_path = tmp.name
            doc_path = tmp_path
        else:
            # Already a path on disk: no copy needed.
            doc_path = source_name

        text = extract_text_from_file(doc_path, file_ext)

        # The extractor signals failure through its return value; surface
        # that message as-is instead of asking the model to summarize it.
        if text == "Unsupported file format" or text.startswith("Error reading file:"):
            return text
        if not text.strip():
            return "Processing error: no text could be extracted from the document"

        # truncation=True keeps documents longer than the model's input
        # limit from overflowing it.
        result = summarizer(
            text,
            max_length=150,
            min_length=30,
            do_sample=False,
            truncation=True,
        )
        return result[0]['generated_text']
    except Exception as e:
        return f"Processing error: {str(e)}"
    finally:
        # Always remove the temp copy, including when extraction or
        # summarization raised.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
def process_image(image):
    """Caption an image and run OCR on it.

    Args:
        image: Path of the image file (the UI passes a file path), or any
            other source accepted by ``PIL.Image.open``.

    Returns:
        On success, ``{"caption": <str>, "ocr_text": <str>}`` where
        ``ocr_text`` is "No readable text found" if OCR found nothing.
        On failure, ``{"error": <message>}``.
    """
    if image is None:
        return {"error": "No image provided"}
    try:
        # The context manager releases the image file handle when done.
        with Image.open(image) as img:
            caption = captioner(img)[0]['generated_text']
            # Hand EasyOCR the file path when there is one: readtext is
            # documented for paths, URLs, bytes and numpy arrays, not for
            # arbitrary PIL image objects. Non-path inputs keep the previous
            # behaviour of passing the opened image.
            ocr_source = image if isinstance(image, str) else img
            ocr_result = reader.readtext(ocr_source)
        # Each OCR result is indexed at [1] for its recognized text.
        ocr_text = " ".join(res[1] for res in ocr_result)
        return {
            "caption": caption,
            "ocr_text": ocr_text if ocr_text else "No readable text found",
        }
    except Exception as e:
        return {"error": str(e)}
def _image_outputs(image):
    """Adapt the dict returned by ``process_image`` to the two image textboxes.

    The "Analyze" button is bound to two output components, so its handler
    must return two values (caption, extracted text). An error is shown in
    the caption box with the text box left empty.
    """
    result = process_image(image)
    if "error" in result:
        return f"Error: {result['error']}", ""
    return result.get("caption", ""), result.get("ocr_text", "")


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 📄 Document & Image Analysis")

    with gr.Tab("Document Summarization"):
        doc_input = gr.File(label="Upload Document")
        doc_output = gr.Textbox(label="Summary")
        doc_button = gr.Button("Summarize")

    with gr.Tab("Image Analysis"):
        # type="filepath" makes the handler receive the image as a path.
        img_input = gr.Image(type="filepath", label="Upload Image")
        caption_output = gr.Textbox(label="Image Caption")
        ocr_output = gr.Textbox(label="Extracted Text")
        img_button = gr.Button("Analyze")

    doc_button.click(process_document, inputs=doc_input, outputs=doc_output)
    img_button.click(_image_outputs, inputs=img_input, outputs=[caption_output, ocr_output])
# Serve the Gradio UI at the site root.
#
# A separate `GET /` handler that redirected to "/" used to be registered
# after this mount. Because the mount at "/" is matched first, that handler
# could never run, and its target was its own path, so it would only have
# redirected to itself. It has been removed.
app = gr.mount_gradio_app(app, demo, path="/")