# app.py
import os
import tempfile
import base64
from pathlib import Path
import io

import gradio as gr
from huggingface_hub import InferenceClient
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
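
# Expected dependencies (inferred from the imports above and the lazy imports below):
#   pip install gradio huggingface_hub langchain langchain-community langchain-huggingface \
#       faiss-cpu sentence-transformers pymupdf pillow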

# ── Globals ───────────────────────────────────────────────────────────────
index = None
retriever = None
extracted_content = None

# ── Inference & Embeddings ─────────────────────────────────────────────────
# Vision-capable instruct model served through the HF Inference API
multimodal_client = InferenceClient(model="microsoft/Phi-3.5-vision-instruct")
# CLIP text encoder via sentence-transformers; embeds the text chunks for FAISS
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/clip-ViT-B-32")

# Temporary dirs for image extraction
TMP_DIR = tempfile.mkdtemp()
FIGURES_DIR = os.path.join(TMP_DIR, "figures")
os.makedirs(FIGURES_DIR, exist_ok=True)

# ── Helpers ─────────────────────────────────────────────────────────────────
def encode_image_to_base64(image_path):
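    """Read an image file and return its base64-encoded contents as a string."""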
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode()


def extract_images_from_pdf(pdf_path):
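    """Save each embedded PDF image as a PNG and caption it with the vision model."""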
    import fitz  # PyMuPDF
    from PIL import Image

    extracted = []
    descriptions = []
    try:
        doc = fitz.open(pdf_path)
        for p in range(len(doc)):
            page = doc.load_page(p)
            for img in page.get_images():
                xref = img[0]
                pix = fitz.Pixmap(doc, xref)
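                # keep GRAY/RGB pixmaps; CMYK (4 color components) can't be written to PNG directly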
                if pix.n - pix.alpha < 4:
                    png = pix.tobytes("png")
                    img_pil = Image.open(io.BytesIO(png))
                    fname = f"page_{p}_img_{xref}.png"
                    path = os.path.join(FIGURES_DIR, fname)
                    img_pil.save(path)
                    desc = analyze_image(path)
                    extracted.append(path)
                    descriptions.append(desc)
                pix = None
        doc.close()
    except Exception as e:
        print(f"Image extraction error: {e}")
    return extracted, descriptions


def analyze_image(image_path):
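    """Caption one extracted image with the vision model; returns a tagged string."""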
    try:
        b64 = encode_image_to_base64(image_path)
        # Attach the image as a data URL through the OpenAI-style chat API;
        # a plain text_generation call has no way to pass pixels to a vision model.
        response = multimodal_client.chat_completion(
            messages=[{
                "role": "user",
                "content": [
                    {"type": "image_url",
                     "image_url": {"url": f"data:image/png;base64,{b64}"}},
                    {"type": "text",
                     "text": "Analyze this image and provide a detailed description. "
                             "Include any text, charts, tables, or important "
                             "visual elements."},
                ],
            }],
            max_tokens=200,
            temperature=0.3,
        )
        out = response.choices[0].message.content
        return f"[IMAGE]: {out.strip()}"
    except Exception as e:
        return f"[IMAGE ERROR]: {e}"


def process_pdf(pdf_file):
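    """Extract page text and image captions from a PDF, chunk them, and build the FAISS retriever."""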
    global index, retriever, extracted_content
    if not pdf_file:
        return None, "❌ Upload a PDF.", gr.update(interactive=False), gr.update(interactive=False)

    # clear old images
    for f in os.listdir(FIGURES_DIR):
        os.remove(os.path.join(FIGURES_DIR, f))

    # gr.File may return a path string / Path, or a tempfile-like object exposing .name
    path = str(pdf_file) if isinstance(pdf_file, (str, Path)) else pdf_file.name
    display_name = os.path.basename(path)
    try:
        import fitz
        doc = fitz.open(path)
        pages = []
        for i in range(len(doc)):
            txt = doc.load_page(i).get_text().strip()
            if txt:
                pages.append(f"[Page {i+1}]\n" + txt)
        doc.close()

        imgs, descs = extract_images_from_pdf(path)
        all_content = pages + descs
        extracted_content = "\n\n".join(all_content)
        if not extracted_content:
            return display_name, "❌ No content extracted.", gr.update(interactive=False), gr.update(interactive=False)

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200, add_start_index=True
        )
        chunks = splitter.split_text(extracted_content)
        index = FAISS.from_texts(chunks, embeddings)
        retriever = index.as_retriever(search_kwargs={"k": 3})

        msg = f"✅ Processed {display_name}: {len(chunks)} chunks."
        return display_name, msg, gr.update(interactive=True), gr.update(interactive=True)

    except Exception as e:
        return display_name, f"❌ PDF error: {e}", gr.update(interactive=False), gr.update(interactive=False)


def ask_question(doc_name, question):
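    """Answer a question over the indexed PDF using retrieved chunks as context."""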
    global retriever
    if not retriever:
        return "❌ Process a PDF first."
    if not question.strip():
        return "❌ Enter a question."

    # retrieve top matches; invoke() on newer LangChain, get_relevant_documents() on older
    try:
        docs = retriever.invoke(question)
    except Exception:
        docs = retriever.get_relevant_documents(question)

    context = "\n\n".join(d.page_content for d in docs)
    prompt = (
        "You are an AI assistant with both text and visual context.\n"
        f"CONTEXT:\n{context}\nQUESTION: {question}\nAnswer:"
    )
    try:
        raw = multimodal_client.text_generation(
            prompt=prompt, max_new_tokens=300, temperature=0.5
        )
        if isinstance(raw, dict):
            out = raw.get("generated_text", str(raw))
        elif isinstance(raw, list) and raw and isinstance(raw[0], dict):
            out = raw[0].get("generated_text", str(raw))
        else:
            out = str(raw)
        return out.strip()
    except Exception as e:
        return f"❌ Generation error: {e}"

# ── Gradio UI ───────────────────────────────────────────────────────────────
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("## 🧠 Unified MultiModal RAG")
    with gr.Row():
        with gr.Column():
            pdf_in = gr.File(label="Upload PDF", file_types=[".pdf"], type="filepath")
            proc_btn = gr.Button("🔄 Process PDF", variant="primary")
            pdf_disp = gr.Textbox(label="Active Doc", interactive=False)
            status  = gr.Textbox(label="Status", interactive=False)
        with gr.Column():
            q_in = gr.Textbox(label="Ask your question…", lines=3, interactive=False)
            ask_btn = gr.Button("🔍 Ask", variant="primary", interactive=False)
            ans_out = gr.Textbox(label="Answer", lines=6, interactive=False)

    # process_pdf drives all four outputs, enabling the question box and Ask button only on success
    proc_btn.click(process_pdf, [pdf_in], [pdf_disp, status, q_in, ask_btn])
    ask_btn.click(ask_question, [pdf_disp, q_in], ans_out)

if __name__ == "__main__":
    demo.launch(debug=True)