import os
import shutil
from typing import List
import gradio as gr
from PIL import Image
# PDF parsing
from pypdf import PdfReader
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy
# Text splitting
from langchain.text_splitter import CharacterTextSplitter
# Vectorstore and embeddings
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
# Vision-language captioning (BLIP)
from transformers import BlipProcessor, BlipForConditionalGeneration
# LLM via HF Inference API
from huggingface_hub import InferenceClient
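# Implied dependencies (pip package names inferred from the imports above,
# versions not pinned here): gradio, pillow, pypdf, "unstructured[pdf]",
# langchain, langchain-community, langchain-huggingface,
# sentence-transformers, faiss-cpu (or faiss-gpu), transformers, huggingface_hub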
# ── Globals ───────────────────────────────────────────────────────────────────
retriever = None
pdf_text: str = ""
# ── Setup directories ──────────────────────────────────────────────────────────
FIGURES_DIR = "figures"
if os.path.exists(FIGURES_DIR):
    shutil.rmtree(FIGURES_DIR)
os.makedirs(FIGURES_DIR, exist_ok=True)
# ── Models & Clients ───────────────────────────────────────────────────────────
hf_client = InferenceClient()  # auth via the HF_TOKEN env var or a cached `huggingface-cli login`
# Embeddings model (local lightweight SBERT)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# BLIP for image captioning
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# ── Helper functions ───────────────────────────────────────────────────────────
def generate_caption(image_path: str) -> str:
    """Caption a single extracted image with BLIP."""
    image = Image.open(image_path).convert("RGB")
    inputs = blip_processor(image, return_tensors="pt")
    outputs = blip_model.generate(**inputs)
    return blip_processor.decode(outputs[0], skip_special_tokens=True)
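# Example usage (hypothetical figure path; BLIP's wording varies by image):
#   generate_caption("figures/figure-1-1.jpg")
#   -> e.g. "a bar chart comparing model accuracy"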
def process_pdf(pdf_file):
    """Index an uploaded PDF: extract text and figures, caption images, build FAISS."""
    global retriever, pdf_text
    if pdf_file is None:
        disable = gr.update(interactive=False)
        return None, "❌ Please upload a PDF.", disable, disable, disable, disable
    # gr.File(type="filepath") hands the handler a plain string path
    pdf_path = pdf_file
    # Read the full text page by page with pypdf
    reader = PdfReader(pdf_path)
    pages = [p.extract_text() or "" for p in reader.pages]
    pdf_text = "\n".join(pages)
    # Extract text elements plus image/table blocks via unstructured
    try:
        elements = partition_pdf(
            filename=pdf_path,
            strategy=PartitionStrategy.HI_RES,
            extract_image_block_types=["Image", "Table"],
            extract_image_block_output_dir=FIGURES_DIR,
        )
        text_elems = [e.text for e in elements if e.category not in ["Image", "Table"] and e.text]
        image_files = [
            os.path.join(FIGURES_DIR, f)
            for f in os.listdir(FIGURES_DIR)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        ]
    except Exception:
        # Fall back to plain page text if hi-res partitioning fails
        text_elems = pages
        image_files = []
    # Caption every extracted figure with BLIP
    captions = [generate_caption(img) for img in image_files]
    # Split text into overlapping chunks
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = []
    for t in text_elems:
        chunks.extend(splitter.split_text(t))
    # Combine text chunks and image captions into one searchable corpus
    docs = chunks + captions
    # Embed and index; FAISS.from_embeddings also needs the embeddings object
    # so the index can embed queries at search time
    vectors = embeddings.embed_documents(docs)
    pairs = list(zip(docs, vectors))
    index = FAISS.from_embeddings(pairs, embeddings)
    retriever = index.as_retriever(search_kwargs={"k": 2})
    status = f"✅ Indexed: {len(chunks)} text chunks + {len(captions)} captions"
    enable = gr.update(interactive=True)
    return os.path.basename(pdf_path), status, enable, enable, enable, enable
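# Retrieval sketch (hypothetical query, after a PDF has been indexed):
#   retriever.get_relevant_documents("What does Figure 2 show?")
# returns the k=2 nearest text chunks / image captions in embedding space,
# which ask_question() below stuffs into the LLM prompt as context.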
def ask_question(pdf_name, question):
    """Answer a question from the retrieved chunks/captions via the HF Inference API."""
    if retriever is None:
        return "❌ Please upload and index a PDF first."
    if not question:
        return "❌ Please ask something."
    docs = retriever.get_relevant_documents(question)
    context = "\n\n".join(d.page_content for d in docs)
    prompt = f"Use the following excerpts to answer:\n{context}\nQuestion: {question}\nAnswer:"
    res = hf_client.chat_completion(
        model="google/gemma-3-27b-it",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=128,
        temperature=0.5,
    )
    # chat_completion returns a ChatCompletionOutput dataclass
    return res.choices[0].message.content.strip()

def generate_summary():
    if not pdf_text:
        return "❌ Please index a PDF first."
    return ask_question(None, f"Summarize concisely:\n{pdf_text[:2000]}")

def extract_keywords():
    if not pdf_text:
        return "❌ Please index a PDF first."
    return ask_question(None, f"Extract 10–15 key terms:\n{pdf_text[:2000]}")

def clear_all():
    """Reset state and disable the question box and action buttons."""
    global retriever, pdf_text
    retriever = None
    pdf_text = ""
    shutil.rmtree(FIGURES_DIR, ignore_errors=True)
    os.makedirs(FIGURES_DIR, exist_ok=True)
    disable = gr.update(interactive=False)
    return None, "", disable, disable, disable, disable

# ── Gradio UI ────────────────────────────────────────────────────────────────
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# Multimodal RAG with HF & LangChain")
    with gr.Row():
        with gr.Column():
            pdf_disp = gr.Textbox(label="Active PDF", interactive=False)
            pdf_file = gr.File(label="Upload PDF", type="filepath", file_types=[".pdf"])
            btn_proc = gr.Button("📄 Process PDF")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column():
            q_in = gr.Textbox(label="Your question", interactive=False)
            btn_ask = gr.Button("❓ Ask", interactive=False)
            ans = gr.Textbox(label="Answer", interactive=False)
    with gr.Row():
        btn_sum = gr.Button("📋 Summary", interactive=False)
        sum_out = gr.Textbox(interactive=False)
        btn_key = gr.Button("🏷️ Keywords", interactive=False)
        key_out = gr.Textbox(interactive=False)
    btn_clear = gr.Button("🗑️ Clear All")

    # Processing enables the question box and action buttons; clearing disables them
    btn_proc.click(process_pdf, [pdf_file], [pdf_disp, status, q_in, btn_ask, btn_sum, btn_key])
    btn_ask.click(ask_question, [pdf_disp, q_in], ans)
    btn_sum.click(generate_summary, [], sum_out)
    btn_key.click(extract_keywords, [], key_out)
    btn_clear.click(clear_all, [], [pdf_disp, status, q_in, btn_ask, btn_sum, btn_key])
if __name__ == "__main__":
    demo.launch(debug=True)
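# To run locally (assumes the dependencies listed above are installed;
# the token value is a placeholder):
#   HF_TOKEN=hf_xxx python app.py
# then open the local URL that Gradio prints.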