import os
import shutil
from typing import List

import torch
import gradio as gr
from PIL import Image

# PDF text extraction
from pypdf import PdfReader

# Unstructured for rich PDF parsing (images, tables)
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy

# Vision-language captioning (BLIP) and image embeddings (CLIP)
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel

# Hugging Face Inference client for LLM chat completions
from huggingface_hub import InferenceClient

# FAISS vectorstore
from langchain_community.vectorstores import FAISS

# Text splitting and embeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
# ── Globals ───────────────────────────────────────────────────────────────────
retriever = None
current_pdf_name = None
combined_texts: List[str] = []  # text chunks + image captions
combined_vectors: List[List[float]] = []
pdf_text: str = ""
# ── Setup ─────────────────────────────────────────────────────────────────────
FIGURES_DIR = "figures"
# Start from an empty figures directory on every launch
if os.path.exists(FIGURES_DIR):
    shutil.rmtree(FIGURES_DIR)
os.makedirs(FIGURES_DIR, exist_ok=True)
# ── Clients & Models ──────────────────────────────────────────────────────────
hf = InferenceClient()  # chat completions; uses HF_TOKEN from the environment or the saved login token
txt_emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
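
# Caption one extracted image with BLIP: load the file, preprocess, generate, decode.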
def generate_caption(image_path: str) -> str:
    image = Image.open(image_path).convert("RGB")
    inputs = blip_processor(image, return_tensors="pt")
    out = blip_model.generate(**inputs)
    return blip_processor.decode(out[0], skip_special_tokens=True)
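
# Embed text with MiniLM sentence embeddings (384-dim vectors).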
def embed_texts(texts: List[str]) -> List[List[float]]:
    return txt_emb.embed_documents(texts)
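
# CLIP image features (512-dim). These live in a different vector space from the
# MiniLM text embeddings, so they are not mixed into the text FAISS index below.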
def embed_images(image_paths: List[str]) -> List[List[float]]:
    feats = []
    for p in image_paths:
        img = Image.open(p).convert("RGB")
        inputs = clip_processor(images=img, return_tensors="pt")
        with torch.no_grad():
            v = clip_model.get_image_features(**inputs)
        feats.append(v[0].cpu().tolist())
    return feats
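
# End-to-end ingestion: extract page text, pull out images/tables, caption the
# images, embed everything, and build the retriever.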
def process_pdf(pdf_file):
    global retriever, current_pdf_name, combined_texts, combined_vectors, pdf_text
    if pdf_file is None:
        return None, "❌ Please upload a PDF file.", gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
    # gr.File(type="filepath") hands us a path string, not a file object
    current_pdf_name = os.path.basename(pdf_file)
    # extract full text page by page
    reader = PdfReader(pdf_file)
    pages = [page.extract_text() or "" for page in reader.pages]
    pdf_text = "\n\n".join(pages)
    # rich parsing for images and tables; fall back to plain page text on failure
    try:
        els = partition_pdf(
            filename=pdf_file,
            strategy=PartitionStrategy.HI_RES,
            extract_image_block_types=["Image", "Table"],
            extract_image_block_output_dir=FIGURES_DIR,
        )
        texts = [e.text for e in els if e.category not in ["Image", "Table"] and e.text]
        imgs = [os.path.join(FIGURES_DIR, f) for f in os.listdir(FIGURES_DIR)
                if f.lower().endswith((".png", ".jpg", ".jpeg"))]
    except Exception:
        texts = pages
        imgs = []
    # split text into overlapping chunks
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = []
    for t in texts:
        chunks.extend(splitter.split_text(t))
    # caption each extracted image with BLIP
    caps = [generate_caption(i) for i in imgs]
    # embed chunks and captions in the same MiniLM space that queries are
    # embedded into; CLIP's 512-dim image vectors cannot share a FAISS index
    # with 384-dim text vectors, so images enter the index via their captions
    combined_texts = chunks + caps
    combined_vectors = embed_texts(combined_texts)
    # build the FAISS index from precomputed (text, vector) pairs
    index = FAISS.from_embeddings(
        text_embeddings=list(zip(combined_texts, combined_vectors)),
        embedding=txt_emb,
    )
    retriever = index.as_retriever(search_kwargs={"k": 2})
status = f"β Indexed '{current_pdf_name}' β {len(chunks)} text chunks + {len(imgs)} images" | |
return current_pdf_name, status, gr.update(interactive=True) | |
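
# Retrieve the top-k chunks for the question and ask the hosted LLM.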
def ask_question(pdf_name, question):
    global retriever
    if retriever is None:
        return "❌ Please process a PDF first."
    if not question.strip():
        return "❌ Enter a question."
    docs = retriever.invoke(question)
    ctx = "\n\n".join(d.page_content for d in docs)
    prompt = f"Use the following context to answer the question.\n\nContext:\n{ctx}\n\nQuestion: {question}\nAnswer:"
    res = hf.chat_completion(
        model="google/gemma-3-27b-it",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=128,
    )
    return res.choices[0].message.content.strip()
def generate_summary():
    return ask_question(None, "Summarize:\n" + pdf_text[:2000])

def extract_keywords():
    return ask_question(None, "Extract keywords:\n" + pdf_text[:2000])
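
# Reset all state and wipe the extracted figures directory.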
def clear_interface():
    global retriever, combined_texts, combined_vectors, pdf_text
    retriever = None
    combined_texts = []
    combined_vectors = []
    pdf_text = ""
    shutil.rmtree(FIGURES_DIR, ignore_errors=True)
    os.makedirs(FIGURES_DIR, exist_ok=True)
    return None, "", gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
# ── UI ────────────────────────────────────────────────────────────────────────
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# DocQueryAI (True Multimodal RAG)")
    with gr.Row():
        with gr.Column():
            pdf_disp = gr.Textbox(label="Active Document", interactive=False)
            pdf_file = gr.File(label="Upload PDF", file_types=[".pdf"], type="filepath")
            btn_process = gr.Button("Process PDF")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column():
            q_in = gr.Textbox(label="Question", lines=3, interactive=False)
            btn_ask = gr.Button("Ask")
            ans = gr.Textbox(label="Answer", interactive=False)
            btn_sum = gr.Button("Summary", interactive=False)
            out_sum = gr.Textbox(label="Summary", interactive=False)
            btn_key = gr.Button("Keywords", interactive=False)
            out_key = gr.Textbox(label="Keywords", interactive=False)
            btn_clear = gr.Button("Clear All")
    # processing enables the question box and the summary/keywords buttons
    btn_process.click(process_pdf, [pdf_file], [pdf_disp, status, q_in, btn_sum, btn_key])
    btn_ask.click(ask_question, [pdf_disp, q_in], ans)
    btn_sum.click(generate_summary, [], out_sum)
    btn_key.click(extract_keywords, [], out_key)
    btn_clear.click(clear_interface, [], [pdf_disp, status, q_in, btn_sum, btn_key])
if __name__ == "__main__":
    demo.launch()