# Source captured from a Hugging Face Space page (page banner read: "Spaces: Sleeping").
import logging
import os
import shutil
from typing import List

import gradio as gr
from PIL import Image

# PDF parsing
from pypdf import PdfReader
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy

# Text splitting
from langchain.text_splitter import CharacterTextSplitter

# Vectorstore and embeddings
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

# Vision-language captioning (BLIP)
from transformers import BlipProcessor, BlipForConditionalGeneration

# LLM via HF Inference API
from huggingface_hub import InferenceClient
# ── Globals ──────────────────────────────────────────────────────────────────
# Shared app state: written by process_pdf() and reset by clear_all().
pdf_text: str = ""
retriever = None

# ── Setup directories ────────────────────────────────────────────────────────
# Figures cropped out of uploaded PDFs are written here.
FIGURES_DIR = "figures"


def _prepare_figures_dir(path: str) -> None:
    """Recreate *path* as an empty directory, discarding anything left from a previous run."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


_prepare_figures_dir(FIGURES_DIR)
# ── Models & Clients ─────────────────────────────────────────────────────────
# huggingface_hub resolves its token from HF_TOKEN or the cached login, not from
# the LangChain-style HUGGINGFACEHUB_API_TOKEN, so that variable is passed
# explicitly. When it is unset, os.environ.get() returns None and the client
# keeps its default token lookup.
hf_client = InferenceClient(token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"))
# Embeddings model (local lightweight SBERT)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# BLIP for image captioning; processor and model must come from the same checkpoint.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
blip_processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT)
# ── Helper functions ─────────────────────────────────────────────────────────
def generate_caption(image_path: str) -> str:
    """Return a BLIP-generated caption for the image at *image_path*.

    The image is converted to RGB before being fed to the BLIP processor,
    and the decoded caption has special tokens stripped.
    """
    # convert() returns a new, fully loaded image, so it remains usable after
    # the context manager closes the file handle that Image.open() acquired.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    inputs = blip_processor(image, return_tensors="pt")
    output_ids = blip_model.generate(**inputs)
    return blip_processor.decode(output_ids[0], skip_special_tokens=True)
def _upload_path(pdf_file) -> str:
    """Return the filesystem path behind a Gradio upload value.

    The UI declares gr.File(type="filepath"), which delivers a plain str path;
    an object exposing ``.name`` (what earlier code assumed) is still accepted.
    """
    if isinstance(pdf_file, (str, os.PathLike)):
        return os.fspath(pdf_file)
    return pdf_file.name


def _reset_figures_dir() -> None:
    """Empty FIGURES_DIR so figures left by a previously processed PDF are not re-indexed."""
    shutil.rmtree(FIGURES_DIR, ignore_errors=True)
    os.makedirs(FIGURES_DIR, exist_ok=True)


def _partition_text_and_images(path: str, pages: List[str]):
    """Split the PDF at *path* into text elements and extracted figure files.

    Uses unstructured's hi-res partitioning, which writes Image/Table crops into
    FIGURES_DIR. If partitioning raises, falls back to the plain per-page text
    in *pages* and no images.

    Returns:
        (text_elements, image_paths), with image paths sorted for a stable order.
    """
    try:
        elements = partition_pdf(
            filename=path,
            strategy=PartitionStrategy.HI_RES,
            extract_image_block_types=["Image", "Table"],
            extract_image_block_output_dir=FIGURES_DIR,
        )
    except Exception:
        # Best effort: keep the app usable with text-only retrieval, but leave a
        # trace instead of silently swallowing the failure.
        logging.getLogger(__name__).exception(
            "hi-res PDF partitioning failed; falling back to pypdf page text"
        )
        return pages, []
    text_elems = [
        e.text for e in elements if e.category not in ("Image", "Table") and e.text
    ]
    image_files = sorted(
        os.path.join(FIGURES_DIR, name)
        for name in os.listdir(FIGURES_DIR)
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    return text_elems, image_files


def process_pdf(pdf_file):
    """Index an uploaded PDF for retrieval-augmented question answering.

    Steps:
        - Read the full text with pypdf.
        - Partition the document into text elements and figure crops.
        - Caption the figures with BLIP.
        - Split the text into ~1000-character chunks.
        - Build a FAISS index over chunks + captions.

    On success, sets the module globals ``retriever`` (top-2 similarity search)
    and ``pdf_text`` together.

    Args:
        pdf_file: Gradio upload value (str path, or an object with ``.name``);
            None when nothing has been uploaded.

    Returns:
        (active PDF name or None, status message, gr.update toggling the
        question box's interactivity).
    """
    global retriever, pdf_text
    if pdf_file is None:
        return None, "β Please upload a PDF.", gr.update(interactive=False)

    path = _upload_path(pdf_file)
    reader = PdfReader(path)
    pages = [p.extract_text() or "" for p in reader.pages]

    # Without this, os.listdir() would also pick up figures extracted from the
    # previously processed PDF and caption/index them as part of this one.
    _reset_figures_dir()
    text_elems, image_files = _partition_text_and_images(path, pages)

    captions = [generate_caption(img) for img in image_files]

    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = [piece for text in text_elems for piece in splitter.split_text(text)]

    docs = chunks + captions
    if not docs:
        # An empty document list cannot be indexed; report it instead of crashing.
        retriever = None
        pdf_text = ""
        return (
            None,
            "No extractable text or images were found in this PDF.",
            gr.update(interactive=False),
        )

    vectors = embeddings.embed_documents(docs)
    # FAISS.from_embeddings requires the embedding model as well: the vector
    # store uses it to embed queries at retrieval time.
    index = FAISS.from_embeddings(list(zip(docs, vectors)), embeddings)
    # Commit both globals only after indexing succeeded, so they stay consistent.
    retriever = index.as_retriever(search_kwargs={"k": 2})
    pdf_text = "\n".join(pages)

    status = f"β Indexed β {len(chunks)} text chunks + {len(captions)} captions"
    return os.path.basename(path), status, gr.update(interactive=True)
def ask_question(pdf_name, question):
    """Answer *question* from the indexed PDF via retrieval-augmented generation.

    Retrieves matching excerpts with the module-level ``retriever``, prepends
    them to the question, and sends the prompt to google/gemma-3-27b-it through
    ``hf_client.chat_completion`` (at most 128 new tokens, temperature 0.5).

    Args:
        pdf_name: Active PDF name from the UI; currently unused.
        question: The user's question, or a task prompt built by
            generate_summary / extract_keywords.

    Returns:
        The model's answer text, or a guidance message when no PDF is indexed
        or the question is empty/blank.
    """
    if retriever is None:
        return "β Please upload + index a PDF first."
    if not question or not question.strip():
        return "β Please ask something."
    # invoke() is the Runnable entry point; get_relevant_documents() is deprecated.
    docs = retriever.invoke(question)
    context = "\n\n".join(d.page_content for d in docs)
    prompt = f"Use the following excerpts to answer:\n{context}\nQuestion: {question}\nAnswer:"
    res = hf_client.chat_completion(
        model="google/gemma-3-27b-it",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=128,
        temperature=0.5,
    )
    # The API types message content as optional; treat a missing one as "".
    return (res["choices"][0]["message"]["content"] or "").strip()
def generate_summary():
    """Request a concise LLM summary of the indexed PDF.

    Only the first 2000 characters of ``pdf_text`` are sent. The request goes
    through ask_question, so retrieved excerpts are prepended as context.
    Returns a guidance message when no PDF text is available.
    """
    if pdf_text:
        opening = pdf_text[:2000]
        return ask_question(None, f"Summarize concisely:\n{opening}")
    return "β Please index a PDF first."
def extract_keywords():
    """Ask the LLM for 10-15 key terms from the start of the indexed PDF.

    Delegates to ask_question with a keyword-extraction prompt over the first
    2000 characters of ``pdf_text``. Returns a guidance message when no PDF
    text is available.
    """
    if pdf_text:
        head = pdf_text[:2000]
        return ask_question(None, f"Extract 10β15 key terms:\n{head}")
    return "β Please index first."
def clear_all():
    """Forget the indexed PDF and start over with an empty figures directory.

    Returns:
        Values for (active PDF box, status box, question box): the first is
        set to None, the second to "", and the question box is made
        non-interactive.
    """
    global pdf_text, retriever
    retriever, pdf_text = None, ""
    # ignore_errors: a missing directory is fine, it is recreated right below.
    shutil.rmtree(FIGURES_DIR, ignore_errors=True)
    os.makedirs(FIGURES_DIR, exist_ok=True)
    question_box = gr.update(interactive=False)
    return None, "", question_box
# ── Gradio UI ────────────────────────────────────────────────────────────────
def _action_button_updates():
    """Return interactivity updates for the Ask / Summary / Keywords buttons.

    The buttons are enabled only while ``retriever`` is set, i.e. after a PDF
    has been indexed and before Clear All resets the state.
    """
    ready = retriever is not None
    return tuple(gr.update(interactive=ready) for _ in range(3))


theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")
# NOTE(review): the emoji in the button labels below appear mis-encoded (UTF-8
# decoded as ISO-8859-7); restore the intended characters.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# Multimodal RAG with HF & LangChain")
    with gr.Row():
        with gr.Column():
            pdf_disp = gr.Textbox(label="Active PDF", interactive=False)
            pdf_file = gr.File(label="Upload PDF", type="filepath", file_types=[".pdf"])
            btn_proc = gr.Button("π Process PDF")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column():
            q_in = gr.Textbox(label="Your question", interactive=False)
            btn_ask = gr.Button("β Ask", interactive=False)
            ans = gr.Textbox(label="Answer", interactive=False)
    with gr.Row():
        btn_sum = gr.Button("π Summary", interactive=False)
        sum_out = gr.Textbox(label="Summary", interactive=False)
        btn_key = gr.Button("π·οΈ Keywords", interactive=False)
        key_out = gr.Textbox(label="Keywords", interactive=False)
    btn_clear = gr.Button("ποΈ Clear All")

    # These buttons start disabled; without re-syncing them after Process and
    # Clear they could never be clicked.
    action_buttons = [btn_ask, btn_sum, btn_key]
    btn_proc.click(process_pdf, [pdf_file], [pdf_disp, status, q_in]).then(
        _action_button_updates, None, action_buttons
    )
    btn_ask.click(ask_question, [pdf_disp, q_in], ans)
    q_in.submit(ask_question, [pdf_disp, q_in], ans)  # Enter also asks
    btn_sum.click(generate_summary, [], sum_out)
    btn_key.click(extract_keywords, [], key_out)
    btn_clear.click(clear_all, [], [pdf_disp, status, q_in]).then(
        _action_button_updates, None, action_buttons
    )
def _launch() -> None:
    """Start the Gradio server for ``demo`` with ``debug=True``."""
    demo.launch(debug=True)


if __name__ == "__main__":
    _launch()