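"""DocQueryAI: multimodal RAG over a single PDF.

Extracts page text with pypdf and embedded images/tables with unstructured,
captions images with BLIP, indexes text chunks and captions in FAISS using
MiniLM embeddings, and answers questions via the Hugging Face Inference API
behind a Gradio UI. Run this script directly to launch the demo.
"""
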
import os
import shutil
from typing import List

import torch
import gradio as gr
from PIL import Image

# Unstructured for PDF parsing
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy

# Vision-language captioning (BLIP)
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel

# Hugging Face Inference client for LLM
from huggingface_hub import InferenceClient

# FAISS vectorstore
from langchain_community.vectorstores import FAISS

# Text embeddings
from langchain_huggingface import HuggingFaceEmbeddings

# ── Globals ───────────────────────────────────────────────────────────────────
retriever = None
current_pdf_name = None
combined_texts: List[str] = []  # text chunks + captions
combined_vectors: List[List[float]] = []
pdf_text: str = ""

# ── Setup ─────────────────────────────────────────────────────────────────────
FIGURES_DIR = "figures"
# start each launch with a clean figures directory
if os.path.exists(FIGURES_DIR):
    shutil.rmtree(FIGURES_DIR)
os.makedirs(FIGURES_DIR, exist_ok=True)

# ── Clients & Models ───────────────────────────────────────────────────────────
hf = InferenceClient()  # chat completions via the HF Inference API (uses any locally configured HF token)
txt_emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")


def generate_caption(image_path: str) -> str:
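    """Generate a short BLIP caption for the image at `image_path`."""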
    image = Image.open(image_path).convert("RGB")
    inputs = blip_processor(image, return_tensors="pt")
    out = blip_model.generate(**inputs)
    return blip_processor.decode(out[0], skip_special_tokens=True)


def embed_texts(texts: List[str]) -> List[List[float]]:
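    """Embed text strings with the MiniLM sentence-transformers model."""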
    return txt_emb.embed_documents(texts)


def embed_images(image_paths: List[str]) -> List[List[float]]:
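    """Return CLIP image features for each image path.

    Note: these vectors live in CLIP's embedding space (and dimensionality),
    so they are not mixed into the text FAISS index below; image content is
    indexed via its BLIP caption instead.
    """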
    feats = []
    for p in image_paths:
        img = Image.open(p).convert("RGB")
        inputs = clip_processor(images=img, return_tensors="pt")
        with torch.no_grad():
            v = clip_model.get_image_features(**inputs)
        feats.append(v[0].cpu().tolist())
    return feats


def process_pdf(pdf_file):
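    """Parse the uploaded PDF, caption extracted images, and build the FAISS retriever."""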
    global retriever, current_pdf_name, combined_texts, combined_vectors, pdf_text
    if pdf_file is None:
        return None, "❌ Please upload a PDF file.", gr.update(interactive=False)

    # gr.File(type="filepath") hands us the path as a plain string
    pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
    current_pdf_name = os.path.basename(pdf_path)

    # extract the full page text with pypdf
    from pypdf import PdfReader
    reader = PdfReader(pdf_path)
    pages = [page.extract_text() or "" for page in reader.pages]
    pdf_text = "\n\n".join(pages)

    # hi-res parsing to pull embedded images and tables out of the PDF
    try:
        els = partition_pdf(
            filename=pdf_path,
            strategy=PartitionStrategy.HI_RES,
            extract_image_block_types=["Image", "Table"],
            extract_image_block_output_dir=FIGURES_DIR,
        )
        texts = [e.text for e in els if e.category not in ["Image", "Table"] and e.text]
        imgs = [os.path.join(FIGURES_DIR, f) for f in os.listdir(FIGURES_DIR)
                if f.lower().endswith((".png", ".jpg", ".jpeg"))]
    except Exception:
        # fall back to plain page text if hi-res parsing is unavailable
        texts = pages
        imgs = []

    # split text chunks
    from langchain.text_splitter import CharacterTextSplitter
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = []
    for t in texts:
        chunks.extend(splitter.split_text(t))
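    # caption every extracted image with BLIP so figures are searchable as text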
    caps = [generate_caption(i) for i in imgs]

    # embed text chunks and image captions with the same text embedder so every
    # vector shares one embedding space and dimension (raw CLIP image features
    # have a different dimension and cannot live in the same FAISS index)
    combined_texts = chunks + caps
    combined_vectors = embed_texts(combined_texts)

    # Build the FAISS index from precomputed (text, vector) pairs; passing the
    # text embedder lets the retriever embed queries consistently at search time.
    index = FAISS.from_embeddings(
        text_embeddings=list(zip(combined_texts, combined_vectors)),
        embedding=txt_emb,
    )
    retriever = index.as_retriever(search_kwargs={"k": 2})
    status = f"✅ Indexed '{current_pdf_name}': {len(chunks)} text chunks + {len(imgs)} images"
    return current_pdf_name, status, gr.update(interactive=True)


def ask_question(pdf_name,question):
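    """Retrieve the most relevant chunks/captions and answer with the LLM."""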
    global retriever
    if retriever is None:
        return "❌ Please process a PDF first."
    if not question.strip():
        return "❌ Enter a question."
    docs = retriever.invoke(question)
    ctx = "\n\n".join(d.page_content for d in docs)
    prompt = f"Use the following context to answer the question.\n\nContext:\n{ctx}\n\nQuestion: {question}\nAnswer:"
    res = hf.chat_completion(
        model="google/gemma-3-27b-it",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=128,
    )
    return res.choices[0].message.content.strip()


def generate_summary():
    return ask_question(None, "Summarize:\n" + pdf_text[:2000])


def extract_keywords():
    return ask_question(None, "Extract keywords:\n" + pdf_text[:2000])

def clear_interface():
    global retriever, current_pdf_name, combined_texts, combined_vectors, pdf_text
    retriever = None
    current_pdf_name = None
    combined_texts = []
    combined_vectors = []
    pdf_text = ""
    shutil.rmtree(FIGURES_DIR, ignore_errors=True)
    os.makedirs(FIGURES_DIR, exist_ok=True)
    return None, "", gr.update(interactive=False)

# ── Gradio UI ──────────────────────────────────────────────────────────────────
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# DocQueryAI (True Multimodal RAG)")
    with gr.Row():
        with gr.Column():
            pdf_disp = gr.Textbox(label="Active Document", interactive=False)
            pdf_file = gr.File(label="Upload PDF", file_types=[".pdf"], type="filepath")
            btn_process = gr.Button("Process PDF")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column():
            q_in = gr.Textbox(label="Your question", lines=3, interactive=False)
            btn_ask = gr.Button("Ask")
            ans = gr.Textbox(label="Answer", interactive=False)
    # summary/keyword buttons stay clickable; ask_question guards against an unprocessed PDF
    btn_sum = gr.Button("Summary")
    out_sum = gr.Textbox(label="Summary", interactive=False)
    btn_key = gr.Button("Keywords")
    out_key = gr.Textbox(label="Keywords", interactive=False)
    btn_clear = gr.Button("Clear All")

    btn_process.click(process_pdf, [pdf_file], [pdf_disp, status, q_in])
    btn_ask.click(ask_question, [pdf_disp, q_in], ans)
    btn_sum.click(generate_summary, [], out_sum)
    btn_key.click(extract_keywords, [], out_key)
    btn_clear.click(clear_interface, [], [pdf_disp, status, q_in])

if __name__ == "__main__":
    demo.launch()