# app.py
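"""Multimodal RAG demo for a Hugging Face Space: extracts text and images
from an uploaded PDF, describes each image with a vision model, indexes the
combined content in FAISS, and answers questions over it through a Gradio UI."""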

import base64
import io
import os
import tempfile

import fitz  # PyMuPDF
import gradio as gr
from huggingface_hub import InferenceClient
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from PIL import Image
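
# Assumed runtime dependencies for the imports above (pip package names):
#   gradio, huggingface_hub, langchain, langchain-community,
#   langchain-huggingface, faiss-cpu, sentence-transformers,
#   pymupdf (imported as `fitz`), pillow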

# ── Globals ─────────────────────────────────────────────────────────────────
index = None
retriever = None
extracted_content = None

# ── Inference & Embeddings ───────────────────────────────────────────────────
multimodal_client = InferenceClient(model="microsoft/Phi-3.5-vision-instruct")
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/clip-ViT-B-32")

# Temporary dirs for image extraction
TMP_DIR = tempfile.mkdtemp()
FIGURES_DIR = os.path.join(TMP_DIR, "figures")
os.makedirs(FIGURES_DIR, exist_ok=True)

# ── Helpers ──────────────────────────────────────────────────────────────────
def encode_image_to_base64(image_path):
    """Read an image file and return its contents as a base64-encoded string."""
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode()


def extract_images_from_pdf(pdf_path):
    """Extract embedded images from a PDF, save them as PNGs, and describe each one."""
    extracted = []
    descriptions = []
    try:
        doc = fitz.open(pdf_path)
        for p in range(len(doc)):
            page = doc.load_page(p)
            for img in page.get_images():
                xref = img[0]
                pix = fitz.Pixmap(doc, xref)
                # Skip CMYK and other 4+ channel colorspaces that PNG can't represent directly
                if pix.n - pix.alpha < 4:
                    png = pix.tobytes("png")
                    img_pil = Image.open(io.BytesIO(png))
                    fname = f"page_{p}_img_{xref}.png"
                    path = os.path.join(FIGURES_DIR, fname)
                    img_pil.save(path)
                    desc = analyze_image(path)
                    extracted.append(path)
                    descriptions.append(desc)
                pix = None  # release the pixmap
        doc.close()
    except Exception as e:
        print(f"Image extraction error: {e}")
    return extracted, descriptions


def analyze_image(image_path):
    """Send one image to the vision model and return a text description."""
    try:
        b64 = encode_image_to_base64(image_path)
        # The image must actually be attached to the request; chat_completion
        # accepts it as a base64 data URL alongside the text prompt.
        response = multimodal_client.chat_completion(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{b64}"},
                        },
                        {
                            "type": "text",
                            "text": (
                                "Analyze this image and provide a detailed description. "
                                "Include any text, charts, tables, or important visual elements."
                            ),
                        },
                    ],
                }
            ],
            max_tokens=200,
            temperature=0.3,
        )
        out = response.choices[0].message.content
        return f"[IMAGE]: {out.strip()}"
    except Exception as e:
        return f"[IMAGE ERROR]: {e}"


def process_pdf(pdf_file):
    """Extract text and image descriptions from the PDF, then build the FAISS index."""
    global index, retriever, extracted_content
    if not pdf_file:
        return None, "❌ Upload a PDF.", gr.update(interactive=False), gr.update(interactive=False)
    # Clear images left over from a previously processed document
    for f in os.listdir(FIGURES_DIR):
        os.remove(os.path.join(FIGURES_DIR, f))
    path = pdf_file  # gr.File(type="filepath") passes the upload as a path string
    name = os.path.basename(path)
    try:
        doc = fitz.open(path)
        pages = []
        for i in range(len(doc)):
            txt = doc.load_page(i).get_text().strip()
            if txt:
                pages.append(f"[Page {i+1}]\n" + txt)
        doc.close()
        imgs, descs = extract_images_from_pdf(path)
        all_content = pages + descs
        extracted_content = "\n\n".join(all_content)
        if not extracted_content:
            return name, "❌ No content extracted.", gr.update(interactive=False), gr.update(interactive=False)
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200, add_start_index=True
        )
        chunks = splitter.split_text(extracted_content)
        index = FAISS.from_texts(chunks, embeddings)
        retriever = index.as_retriever(search_kwargs={"k": 3})
        msg = f"✅ Processed {name} → {len(chunks)} chunks."
        return name, msg, gr.update(interactive=True), gr.update(interactive=True)
    except Exception as e:
        return name, f"❌ PDF error: {e}", gr.update(interactive=False), gr.update(interactive=False)


def ask_question(doc_name, question):
    """Retrieve the most relevant chunks and generate an answer from them."""
    global retriever
    if not retriever:
        return "❌ Process a PDF first."
    if not question.strip():
        return "❌ Enter a question."
    # Retrieve context, falling back to the older LangChain retriever API
    try:
        docs = retriever.invoke(question)
    except Exception:
        docs = retriever.get_relevant_documents(question)
    context = "\n\n".join(d.page_content for d in docs)
    prompt = (
        "You are an AI assistant with both text and visual context.\n"
        f"CONTEXT:\n{context}\nQUESTION: {question}\nAnswer:"
    )
    try:
        raw = multimodal_client.text_generation(
            prompt=prompt, max_new_tokens=300, temperature=0.5
        )
        # text_generation normally returns a string; unwrap dict/list shapes defensively
        if isinstance(raw, dict):
            out = raw.get("generated_text", str(raw))
        elif isinstance(raw, list) and raw and isinstance(raw[0], dict):
            out = raw[0].get("generated_text", str(raw))
        else:
            out = str(raw)
        return out.strip()
    except Exception as e:
        return f"❌ Generation error: {e}"


# ── Gradio UI ────────────────────────────────────────────────────────────────
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")

with gr.Blocks(theme=theme) as demo:
    gr.Markdown("## 🧠 Unified MultiModal RAG")
    with gr.Row():
        with gr.Column():
            pdf_in = gr.File(label="Upload PDF", file_types=[".pdf"], type="filepath")
            proc_btn = gr.Button("📄 Process PDF", variant="primary")
            pdf_disp = gr.Textbox(label="Active Doc", interactive=False)
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column():
            q_in = gr.Textbox(label="Ask your question…", lines=3, interactive=False)
            ask_btn = gr.Button("🔍 Ask", variant="primary", interactive=False)
            ans_out = gr.Textbox(label="Answer", lines=6, interactive=False)

    # process_pdf enables the question box and Ask button only after a successful run
    proc_btn.click(process_pdf, [pdf_in], [pdf_disp, status, q_in, ask_btn])
    ask_btn.click(ask_question, [pdf_disp, q_in], ans_out)


if __name__ == "__main__":
    demo.launch(debug=True)
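
# Example local run (assumes the packages listed near the imports are installed
# and, if the Inference API requires authentication, HF_TOKEN is set):
#   python app.py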