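"""Multimodal RAG demo: index PDF text and image captions in FAISS, then answer
questions over the combined content through a Gradio UI."""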
import os
import shutil
import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy
from huggingface_hub import InferenceClient

# Directories
PDF_DIR = "pdfs"
FIGURE_DIR = "figures"
os.makedirs(PDF_DIR, exist_ok=True)
os.makedirs(FIGURE_DIR, exist_ok=True)

# Embeddings and Model Setup
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# FAISS can't build an index from an empty list (it needs at least one text to
# infer the embedding dimension), so seed it with a whitespace placeholder.
vector_store = FAISS.from_texts([" "], embedding_model)

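# HuggingFaceHub calls the hosted Inference API and reads the
# HUGGINGFACEHUB_API_TOKEN environment variable for authentication.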
llm = HuggingFaceHub(repo_id="google/flan-t5-base", model_kwargs={"temperature": 0.5, "max_length": 512})

template = """
Use the following context to answer the question. If the answer is unknown, say so.
Context: {context}
Question: {question}
Answer (3 sentences max):
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])

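# The retriever wraps vector_store directly, so chunks added later via
# add_texts become searchable without rebuilding the chain.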
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=vector_store.as_retriever(),
    chain_type_kwargs={"prompt": prompt}
)

# Hugging Face Inference API Client (for image captioning, etc.)
vision_model = InferenceClient("Salesforce/blip-image-captioning-base")

def extract_image_text(file_path):
    # InferenceClient.image_to_text accepts a file path, raw bytes, or a URL,
    # so pass the path through rather than a PIL Image object.
    caption = vision_model.image_to_text(file_path)
    # Newer huggingface_hub releases return an object with .generated_text;
    # older ones return the caption as a plain string.
    return getattr(caption, "generated_text", caption)

def process_pdf(file_path):
    # With gr.File(type="filepath"), Gradio passes the uploaded file's temp path,
    # so copy it into the PDF directory rather than reading a file object.
    pdf_path = os.path.join(PDF_DIR, os.path.basename(file_path))
    shutil.copy(file_path, pdf_path)

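    # The hi_res strategy depends on unstructured's optional PDF extras
    # (e.g. `pip install "unstructured[pdf]"` plus poppler) being installed.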
    elements = partition_pdf(
        pdf_path,
        strategy=PartitionStrategy.HI_RES,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_output_dir=FIGURE_DIR
    )

    texts = [el.text for el in elements if el.category not in ["Image", "Table"]]

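    # Caption every file currently in FIGURE_DIR; note this also re-captions
    # figures extracted from previously processed PDFs.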
    for fig_file in os.listdir(FIGURE_DIR):
        fig_path = os.path.join(FIGURE_DIR, fig_file)
        caption = extract_image_text(fig_path)
        texts.append(caption)

    full_text = "\n\n".join(texts)

    # Chunking
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    docs = splitter.split_text(full_text)
    vector_store.add_texts(docs)

    return f"Processed {os.path.basename(pdf_path)} with {len(docs)} text chunks."

def answer_query(question):
    return qa_chain.run(question)

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# πŸ“„πŸ“· Multimodal RAG with Hugging Face")

    with gr.Row():
        # "filepath" hands the callback a path string (Gradio 4.x; older releases used type="file")
        file_input = gr.File(label="Upload PDF", type="filepath")
        upload_btn = gr.Button("Process PDF")
        status = gr.Textbox(label="Processing Status")

    with gr.Row():
        question = gr.Textbox(label="Ask a Question")
        ask_btn = gr.Button("Get Answer")
        answer_box = gr.Textbox(label="Answer")

    upload_btn.click(fn=process_pdf, inputs=file_input, outputs=status)
    ask_btn.click(fn=answer_query, inputs=question, outputs=answer_box)

demo.launch()