# app.py
import os
import tempfile
import base64

import fitz  # PyMuPDF
import gradio as gr
from huggingface_hub import InferenceClient
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

# ── Globals ───────────────────────────────────────────────────────────────
index = None
retriever = None
extracted_content = None

# ── Inference & Embeddings ─────────────────────────────────────────────────
multimodal_client = InferenceClient(model="microsoft/Phi-3.5-vision-instruct")
# Note: CLIP's text encoder truncates input at 77 tokens, so long chunks are
# embedded by their beginnings only.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/clip-ViT-B-32")

# Temporary dirs for image extraction
TMP_DIR = tempfile.mkdtemp()
FIGURES_DIR = os.path.join(TMP_DIR, "figures")
os.makedirs(FIGURES_DIR, exist_ok=True)

# ── Helpers ─────────────────────────────────────────────────────────────────
def encode_image_to_base64(image_path):
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode()


def extract_images_from_pdf(pdf_path):
    """Save each embedded image to FIGURES_DIR and caption it with the vision model."""
    extracted, descriptions = [], []
    try:
        doc = fitz.open(pdf_path)
        for page_num in range(len(doc)):
            page = doc.load_page(page_num)
            for img in page.get_images():
                xref = img[0]
                pix = fitz.Pixmap(doc, xref)
                # GRAY/RGB only; CMYK and similar can't be written straight to PNG
                if pix.n - pix.alpha < 4:
                    path = os.path.join(FIGURES_DIR, f"page_{page_num}_img_{xref}.png")
                    pix.save(path)
                    extracted.append(path)
                    descriptions.append(analyze_image(path))
                pix = None  # release the pixmap
        doc.close()
    except Exception as e:
        print(f"Image extraction error: {e}")
    return extracted, descriptions


def analyze_image(image_path):
    """Caption one extracted figure with the vision model."""
    try:
        b64 = encode_image_to_base64(image_path)
        # Send the image as a base64 data URL through the OpenAI-compatible
        # chat API; plain text_generation cannot carry image inputs.
        resp = multimodal_client.chat_completion(
            messages=[{
                "role": "user",
                "content": [
                    {"type": "image_url",
                     "image_url": {"url": f"data:image/png;base64,{b64}"}},
                    {"type": "text",
                     "text": "Analyze this image and provide a detailed description. "
                             "Include any text, charts, tables, or important visual elements."},
                ],
            }],
            max_tokens=200,
            temperature=0.3,
        )
        out = resp.choices[0].message.content
        return f"[IMAGE]: {out.strip()}"
    except Exception as e:
        return f"[IMAGE ERROR]: {e}"


def process_pdf(pdf_file):
    global index, retriever, extracted_content
    if not pdf_file:
        return None, "❌ Upload a PDF.", gr.update(interactive=False), gr.update(interactive=False)
    # clear images left over from the previous document
    for f in os.listdir(FIGURES_DIR):
        os.remove(os.path.join(FIGURES_DIR, f))
    # gr.File(type="filepath") passes a plain path string; older Gradio
    # versions pass a tempfile wrapper with a .name attribute
    path = pdf_file.name if hasattr(pdf_file, "name") else str(pdf_file)
    name = os.path.basename(path)
    try:
        doc = fitz.open(path)
        pages = []
        for i in range(len(doc)):
            txt = doc.load_page(i).get_text().strip()
            if txt:
                pages.append(f"[Page {i+1}]\n" + txt)
        doc.close()
        _, descs = extract_images_from_pdf(path)
        extracted_content = "\n\n".join(pages + descs)
        if not extracted_content:
            return name, "❌ No content extracted.", gr.update(interactive=False), gr.update(interactive=False)
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200, add_start_index=True
        )
        chunks = splitter.split_text(extracted_content)
        index = FAISS.from_texts(chunks, embeddings)
        retriever = index.as_retriever(search_kwargs={"k": 3})
        msg = f"✅ Processed {name}: {len(chunks)} chunks."
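        # The four return values below map onto the UI outputs wired in
        # proc_btn.click: (pdf_disp, status, q_in, ask_btn).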
        return name, msg, gr.update(interactive=True), gr.update(interactive=True)
    except Exception as e:
        return name, f"❌ PDF error: {e}", gr.update(interactive=False), gr.update(interactive=False)


def ask_question(doc_name, question):
    global retriever
    if not retriever:
        return "❌ Process a PDF first."
    if not question.strip():
        return "❌ Enter a question."
    # retrieve: invoke() is the current LangChain API; fall back to the
    # deprecated get_relevant_documents() on older versions
    try:
        docs = retriever.invoke(question)
    except Exception:
        docs = retriever.get_relevant_documents(question)
    context = "\n\n".join(d.page_content for d in docs)
    prompt = (
        "You are an AI assistant with both text and visual context.\n"
        f"CONTEXT:\n{context}\nQUESTION: {question}\nAnswer:"
    )
    try:
        raw = multimodal_client.text_generation(
            prompt=prompt, max_new_tokens=300, temperature=0.5
        )
        # text_generation normally returns a string, but some endpoints wrap
        # the result in a dict or a list of dicts
        if isinstance(raw, dict):
            out = raw.get("generated_text", str(raw))
        elif isinstance(raw, list) and raw and isinstance(raw[0], dict):
            out = raw[0].get("generated_text", str(raw))
        else:
            out = str(raw)
        return out.strip()
    except Exception as e:
        return f"❌ Generation error: {e}"


# ── Gradio UI ───────────────────────────────────────────────────────────────
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("## 🧠 Unified MultiModal RAG")
    with gr.Row():
        with gr.Column():
            pdf_in = gr.File(label="Upload PDF", file_types=[".pdf"], type="filepath")
            proc_btn = gr.Button("🔄 Process PDF", variant="primary")
            pdf_disp = gr.Textbox(label="Active Doc", interactive=False)
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column():
            q_in = gr.Textbox(label="Ask your question…", lines=3, interactive=False)
            ask_btn = gr.Button("🔍 Ask", variant="primary", interactive=False)
            ans_out = gr.Textbox(label="Answer", lines=6, interactive=False)
    # process_pdf enables the question box and Ask button only on success
    proc_btn.click(process_pdf, [pdf_in], [pdf_disp, status, q_in, ask_btn])
    ask_btn.click(ask_question, [pdf_disp, q_in], ans_out)

if __name__ == "__main__":
    demo.launch(debug=True)
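
# ── Headless usage (a minimal sketch; "sample.pdf" is a hypothetical file) ──
# The pipeline can be exercised without the Gradio UI, e.g. for smoke tests:
#
#   from app import process_pdf, ask_question
#   name, status_msg, *_ = process_pdf("sample.pdf")
#   print(status_msg)
#   print(ask_question(name, "What does the first figure show?"))
#
# Dependencies (assumed, unpinned): pip install gradio huggingface_hub langchain \
#   langchain-community langchain-huggingface faiss-cpu sentence-transformers pymupdf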