# app.py
import os
import tempfile
import base64
from pathlib import Path
import io
import gradio as gr
from huggingface_hub import InferenceClient
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
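# Dependencies implied by the imports above (package names are a best guess, not pinned here):
# gradio, huggingface_hub, langchain, langchain-community, langchain-huggingface,
# sentence-transformers, faiss-cpu (or faiss-gpu), pymupdf (imported as fitz), pillow.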
# ── Globals ───────────────────────────────────────────────────────────────
index = None
retriever = None
extracted_content = None
# ── Inference & Embeddings ─────────────────────────────────────────────────
multimodal_client = InferenceClient(model="microsoft/Phi-3.5-vision-instruct")
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/clip-ViT-B-32")
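# Note: CLIP's text encoder handles only short passages (it truncates around 77 tokens),
# so the 1000-character chunks built below are likely embedded from their beginnings only.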
# Temporary dirs for image extraction
TMP_DIR = tempfile.mkdtemp()
FIGURES_DIR = os.path.join(TMP_DIR, "figures")
os.makedirs(FIGURES_DIR, exist_ok=True)
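# Pipeline: extract page text and embedded images from the PDF, describe each image with
# the hosted model, chunk all of that text, index the chunks in FAISS, then answer
# questions by retrieving the top chunks and prompting the model with them as context.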
# ── Helpers ─────────────────────────────────────────────────────────────────
def encode_image_to_base64(image_path):
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode()
def extract_images_from_pdf(pdf_path):
    """Extract embedded images from the PDF into FIGURES_DIR and return
    their file paths along with generated descriptions."""
    import fitz  # PyMuPDF
    from PIL import Image

    extracted = []
    descriptions = []
    try:
        doc = fitz.open(pdf_path)
        for p in range(len(doc)):
            page = doc.load_page(p)
            for img in page.get_images():
                xref = img[0]
                pix = fitz.Pixmap(doc, xref)
                # only save grayscale/RGB images (skip CMYK and similar)
                if pix.n - pix.alpha < 4:
                    png = pix.tobytes("png")
                    img_pil = Image.open(io.BytesIO(png))
                    fname = f"page_{p}_img_{xref}.png"
                    path = os.path.join(FIGURES_DIR, fname)
                    img_pil.save(path)
                    desc = analyze_image(path)
                    extracted.append(path)
                    descriptions.append(desc)
                pix = None
        doc.close()
    except Exception as e:
        print(f"Image extraction error: {e}")
    return extracted, descriptions
def analyze_image(image_path):
    """Ask the hosted model to describe an extracted image."""
    try:
        # NOTE: the base64 data is prepared here, but the text_generation call below
        # sends a text-only prompt, so the image bytes themselves are never attached.
        b64 = encode_image_to_base64(image_path)
        prompt = (
            "Analyze this image and provide a detailed description. "
            "Include any text, charts, tables, or important visual elements.\n"
            "Image: [data]\nDescription:"
        )
        raw = multimodal_client.text_generation(
            prompt=prompt, max_new_tokens=200, temperature=0.3
        )
        # Handle dict or list wrapping of the generation result
        if isinstance(raw, dict):
            out = raw.get("generated_text", str(raw))
        elif isinstance(raw, list) and raw and isinstance(raw[0], dict):
            out = raw[0].get("generated_text", str(raw))
        else:
            out = str(raw)
        return f"[IMAGE]: {out.strip()}"
    except Exception as e:
        return f"[IMAGE ERROR]: {e}"
def process_pdf(pdf_file):
    """Extract text and images from the uploaded PDF, then build the FAISS index."""
    global index, retriever, extracted_content
    if not pdf_file:
        return None, "❌ Upload a PDF.", gr.update(interactive=False)
    # clear images extracted from any previous document
    for f in os.listdir(FIGURES_DIR):
        os.remove(os.path.join(FIGURES_DIR, f))
    # pdf_file may be a path string/Path or an object exposing a .name attribute
    path = str(pdf_file) if isinstance(pdf_file, (str, Path)) else pdf_file.name
    name = os.path.basename(path)
    try:
        import fitz  # PyMuPDF
        doc = fitz.open(path)
        pages = []
        for i in range(len(doc)):
            txt = doc.load_page(i).get_text().strip()
            if txt:
                pages.append(f"[Page {i+1}]\n" + txt)
        doc.close()
        imgs, descs = extract_images_from_pdf(path)
        all_content = pages + descs
        extracted_content = "\n\n".join(all_content)
        if not extracted_content:
            return name, "❌ No content extracted.", gr.update(interactive=False)
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200, add_start_index=True
        )
        chunks = splitter.split_text(extracted_content)
        index = FAISS.from_texts(chunks, embeddings)
        retriever = index.as_retriever(search_kwargs={"k": 3})
        msg = f"✅ Processed {name} - {len(chunks)} chunks."
        return name, msg, gr.update(interactive=True)
    except Exception as e:
        return name, f"❌ PDF error: {e}", gr.update(interactive=False)
def ask_question(doc_name, question):
    """Retrieve the most relevant chunks and generate an answer."""
    global retriever
    if not retriever:
        return "❌ Process a PDF first."
    if not question.strip():
        return "❌ Enter a question."
    # retrieve supporting chunks (fall back to the older retriever API if needed)
    try:
        docs = retriever.invoke(question)
    except Exception:
        docs = retriever.get_relevant_documents(question)
    context = "\n\n".join(d.page_content for d in docs)
    prompt = (
        "You are an AI assistant with both text and visual context.\n"
        f"CONTEXT:\n{context}\nQUESTION: {question}\nAnswer:"
    )
    try:
        raw = multimodal_client.text_generation(
            prompt=prompt, max_new_tokens=300, temperature=0.5
        )
        if isinstance(raw, dict):
            out = raw.get("generated_text", str(raw))
        elif isinstance(raw, list) and raw and isinstance(raw[0], dict):
            out = raw[0].get("generated_text", str(raw))
        else:
            out = str(raw)
        return out.strip()
    except Exception as e:
        return f"❌ Generation error: {e}"
# ── Gradio UI ───────────────────────────────────────────────────────────────
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("## 🧠 Unified MultiModal RAG")
    with gr.Row():
        with gr.Column():
            # type="filepath" hands the handlers a path string
            # ("file" is no longer an accepted value in recent Gradio releases)
            pdf_in = gr.File(label="Upload PDF", file_types=[".pdf"], type="filepath")
            proc_btn = gr.Button("🔄 Process PDF", variant="primary")
            pdf_disp = gr.Textbox(label="Active Doc", interactive=False)
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column():
            q_in = gr.Textbox(label="Ask your question…", lines=3, interactive=False)
            ask_btn = gr.Button("🔍 Ask", variant="primary", interactive=False)
            ans_out = gr.Textbox(label="Answer", lines=6, interactive=False)
    proc_btn.click(process_pdf, [pdf_in], [pdf_disp, status, q_in])
    # enable the Ask button only after processing (outputs is the third positional argument)
    proc_btn.click(lambda: gr.update(interactive=True), None, [ask_btn])
    ask_btn.click(ask_question, [pdf_disp, q_in], ans_out)
if __name__ == "__main__":
    demo.launch(debug=True)