ikraamkb commited on
Commit
c3071ac
Β·
verified Β·
1 Parent(s): 5e30a65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -29
app.py CHANGED
@@ -1,36 +1,38 @@
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.responses import RedirectResponse
3
  import gradio as gr
4
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
5
  import tempfile
6
  import os
7
  from PIL import Image
8
  import fitz # PyMuPDF
9
  import docx
10
- import openpyxl
11
- from pptx import Presentation
12
  import easyocr
13
 
14
  app = FastAPI()
15
 
16
- # Initialize models with error handling
 
 
 
 
17
  try:
18
- # Load summarization model directly with tokenizer
19
- tokenizer = AutoTokenizer.from_pretrained("FeruzaBoynazarovaas/my_awesome_billsum_model", use_fast=False)
20
- model = AutoModelForSeq2SeqLM.from_pretrained("FeruzaBoynazarovaas/my_awesome_billsum_model")
21
  summarizer = pipeline(
22
- "text2text-generation",
23
- model=model,
24
- tokenizer=tokenizer
25
  )
26
  except Exception as e:
27
  print(f"Error loading summarizer: {e}")
28
- # Fallback to a default model if custom fails
29
- summarizer = pipeline("text2text-generation", model="t5-small")
 
 
 
 
 
30
 
31
- # Other models (these should work fine)
32
- captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
33
- reader = easyocr.Reader(['en'])
34
 
35
  def extract_text_from_file(file_path: str, file_type: str):
36
  """Extract text from different document formats"""
@@ -41,26 +43,24 @@ def extract_text_from_file(file_path: str, file_type: str):
41
  elif file_type == "docx":
42
  doc = docx.Document(file_path)
43
  return "\n".join(p.text for p in doc.paragraphs)
44
- elif file_type == "pptx":
45
- prs = Presentation(file_path)
46
- return "\n".join(shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text"))
47
- elif file_type == "xlsx":
48
- wb = openpyxl.load_workbook(file_path)
49
- return "\n".join(str(cell.value) for sheet in wb for row in sheet for cell in row)
50
  else:
51
- return "Unsupported file format"
52
  except Exception as e:
53
  return f"Error reading file: {str(e)}"
54
 
55
  def process_document(file):
 
56
  try:
57
  file_ext = os.path.splitext(file.name)[1][1:].lower()
 
 
 
58
  with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_ext}") as tmp:
59
  tmp.write(file.read())
60
  tmp_path = tmp.name
61
 
62
  text = extract_text_from_file(tmp_path, file_ext)
63
- summary = summarizer(text, max_length=150, min_length=30, do_sample=False)[0]['generated_text']
64
 
65
  os.unlink(tmp_path)
66
  return summary
@@ -68,11 +68,17 @@ def process_document(file):
68
  return f"Processing error: {str(e)}"
69
 
70
  def process_image(image):
 
71
  try:
72
  img = Image.open(image)
 
 
73
  caption = captioner(img)[0]['generated_text']
 
 
74
  ocr_result = reader.readtext(img)
75
  ocr_text = " ".join([res[1] for res in ocr_result])
 
76
  return {
77
  "caption": caption,
78
  "ocr_text": ocr_text if ocr_text else "No readable text found"
@@ -81,25 +87,29 @@ def process_image(image):
81
  return {"error": str(e)}
82
 
83
  # Gradio Interface
84
- with gr.Blocks() as demo:
85
- gr.Markdown("# πŸ“„ Document & Image Analysis")
86
 
87
  with gr.Tab("Document Summarization"):
88
- doc_input = gr.File(label="Upload Document")
 
89
  doc_output = gr.Textbox(label="Summary")
90
  doc_button = gr.Button("Summarize")
91
 
92
  with gr.Tab("Image Analysis"):
 
93
  img_input = gr.Image(type="filepath", label="Upload Image")
94
- caption_output = gr.Textbox(label="Image Caption")
95
- ocr_output = gr.Textbox(label="Extracted Text")
 
96
  img_button = gr.Button("Analyze")
97
 
98
  doc_button.click(process_document, inputs=doc_input, outputs=doc_output)
99
  img_button.click(process_image, inputs=img_input, outputs=[caption_output, ocr_output])
100
 
 
101
  app = gr.mount_gradio_app(app, demo, path="/")
102
 
103
  @app.get("/")
104
- def redirect():
105
  return RedirectResponse(url="/")
 
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.responses import RedirectResponse
3
  import gradio as gr
4
+ from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
5
  import tempfile
6
  import os
7
  from PIL import Image
8
  import fitz # PyMuPDF
9
  import docx
 
 
10
  import easyocr
11
 
12
  app = FastAPI()
13
 
14
+ # Lightweight model choices
15
+ SUMMARIZATION_MODEL = "facebook/bart-large-cnn" # 500MB
16
+ IMAGE_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-base" # 300MB
17
+
18
+ # Initialize models
19
  try:
 
 
 
20
  summarizer = pipeline(
21
+ "summarization",
22
+ model=SUMMARIZATION_MODEL,
23
+ device="cpu"
24
  )
25
  except Exception as e:
26
  print(f"Error loading summarizer: {e}")
27
+ summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6") # Fallback 250MB model
28
+
29
+ captioner = pipeline(
30
+ "image-to-text",
31
+ model=IMAGE_CAPTIONING_MODEL,
32
+ device="cpu"
33
+ )
34
 
35
+ reader = easyocr.Reader(['en']) # Lightweight OCR
 
 
36
 
37
  def extract_text_from_file(file_path: str, file_type: str):
38
  """Extract text from different document formats"""
 
43
  elif file_type == "docx":
44
  doc = docx.Document(file_path)
45
  return "\n".join(p.text for p in doc.paragraphs)
 
 
 
 
 
 
46
  else:
47
+ return "Unsupported file format (only PDF/DOCX supported in lightweight version)"
48
  except Exception as e:
49
  return f"Error reading file: {str(e)}"
50
 
51
  def process_document(file):
52
+ """Handle document summarization"""
53
  try:
54
  file_ext = os.path.splitext(file.name)[1][1:].lower()
55
+ if file_ext not in ["pdf", "docx"]:
56
+ return "Lightweight version only supports PDF and DOCX"
57
+
58
  with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_ext}") as tmp:
59
  tmp.write(file.read())
60
  tmp_path = tmp.name
61
 
62
  text = extract_text_from_file(tmp_path, file_ext)
63
+ summary = summarizer(text, max_length=130, min_length=30, do_sample=False)[0]['summary_text']
64
 
65
  os.unlink(tmp_path)
66
  return summary
 
68
  return f"Processing error: {str(e)}"
69
 
70
  def process_image(image):
71
+ """Handle image captioning and OCR"""
72
  try:
73
  img = Image.open(image)
74
+
75
+ # Get caption
76
  caption = captioner(img)[0]['generated_text']
77
+
78
+ # Get OCR text
79
  ocr_result = reader.readtext(img)
80
  ocr_text = " ".join([res[1] for res in ocr_result])
81
+
82
  return {
83
  "caption": caption,
84
  "ocr_text": ocr_text if ocr_text else "No readable text found"
 
87
  return {"error": str(e)}
88
 
89
  # Gradio Interface
90
+ with gr.Blocks(title="Lightweight Document & Image Analysis") as demo:
91
+ gr.Markdown("## πŸ“„ Lightweight Document & Image Analysis")
92
 
93
  with gr.Tab("Document Summarization"):
94
+ gr.Markdown("Supports PDF and DOCX files (max 10MB)")
95
+ doc_input = gr.File(label="Upload Document", file_types=[".pdf", ".docx"])
96
  doc_output = gr.Textbox(label="Summary")
97
  doc_button = gr.Button("Summarize")
98
 
99
  with gr.Tab("Image Analysis"):
100
+ gr.Markdown("Get captions and extracted text from images")
101
  img_input = gr.Image(type="filepath", label="Upload Image")
102
+ with gr.Accordion("Results", open=False):
103
+ caption_output = gr.Textbox(label="Image Caption")
104
+ ocr_output = gr.Textbox(label="Extracted Text")
105
  img_button = gr.Button("Analyze")
106
 
107
  doc_button.click(process_document, inputs=doc_input, outputs=doc_output)
108
  img_button.click(process_image, inputs=img_input, outputs=[caption_output, ocr_output])
109
 
110
+ # Mount Gradio app
111
  app = gr.mount_gradio_app(app, demo, path="/")
112
 
113
  @app.get("/")
114
+ def redirect_to_interface():
115
  return RedirectResponse(url="/")