import os
import shutil

import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy
from huggingface_hub import InferenceClient

# Directories
PDF_DIR = "pdfs"
FIGURE_DIR = "figures"
os.makedirs(PDF_DIR, exist_ok=True)
os.makedirs(FIGURE_DIR, exist_ok=True)

# Embeddings and model setup.
# Both HuggingFaceHub and the Inference API client authenticate via the
# HUGGINGFACEHUB_API_TOKEN environment variable, which must be set.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# FAISS cannot be built from an empty list (it needs at least one vector to
# infer the embedding dimension), so seed the index with a placeholder text.
vector_store = FAISS.from_texts(["placeholder"], embedding_model)

llm = HuggingFaceHub(
    repo_id="google/flan-t5-base",
    model_kwargs={"temperature": 0.5, "max_length": 512},
)

template = """
Use the following context to answer the question. If the answer is unknown, say so.
Context: {context}
Question: {question}
Answer (3 sentences max):
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=vector_store.as_retriever(),
    chain_type_kwargs={"prompt": prompt},
)

# Hugging Face Inference API client (for image captioning)
vision_model = InferenceClient("Salesforce/blip-image-captioning-base")


def extract_image_text(file_path):
    """Caption an extracted figure with the BLIP model via the Inference API."""
    # image_to_text accepts a file path (or raw bytes) directly.
    result = vision_model.image_to_text(file_path)
    # Newer huggingface_hub versions return an object with .generated_text;
    # older ones return a plain string. Handle both.
    return getattr(result, "generated_text", result)


def process_pdf(file_path):
    """Extract text and figures from an uploaded PDF and index them."""
    # With gr.File(type="filepath"), Gradio passes the path of a temp copy
    # of the upload, so copy it into PDF_DIR rather than reading a file object.
    pdf_path = os.path.join(PDF_DIR, os.path.basename(file_path))
    shutil.copy(file_path, pdf_path)

    # Snapshot the figure directory so we only caption figures produced by
    # this PDF, not leftovers from earlier uploads.
    existing_figures = set(os.listdir(FIGURE_DIR))

    elements = partition_pdf(
        pdf_path,
        strategy=PartitionStrategy.HI_RES,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_output_dir=FIGURE_DIR,
    )

    # Keep the plain-text elements; images and tables are handled below.
    texts = [el.text for el in elements if el.category not in ["Image", "Table"]]

    # Caption each newly extracted figure and fold the caption into the corpus.
    for fig_file in sorted(set(os.listdir(FIGURE_DIR)) - existing_figures):
        fig_path = os.path.join(FIGURE_DIR, fig_file)
        texts.append(extract_image_text(fig_path))

    full_text = "\n\n".join(texts)

    # Chunk the combined text and add the chunks to the FAISS index.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    docs = splitter.split_text(full_text)
    vector_store.add_texts(docs)

    return f"Processed {os.path.basename(file_path)} with {len(docs)} text chunks."


def answer_query(question):
    return qa_chain.run(question)


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 📄📷 Multimodal RAG with Hugging Face")
    with gr.Row():
        file_input = gr.File(label="Upload PDF", type="filepath")
        upload_btn = gr.Button("Process PDF")
        status = gr.Textbox(label="Processing Status")
    with gr.Row():
        question = gr.Textbox(label="Ask a Question")
        ask_btn = gr.Button("Get Answer")
        answer_box = gr.Textbox(label="Answer")

    upload_btn.click(fn=process_pdf, inputs=file_input, outputs=status)
    ask_btn.click(fn=answer_query, inputs=question, outputs=answer_box)

demo.launch()
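
# ------------------------------------------------------------------
# Optional smoke test (a sketch under assumptions, not part of the
# original script): comment out demo.launch() above and uncomment the
# lines below to exercise the pipeline without the UI.
# "pdfs/sample.pdf" is a hypothetical local file you supply yourself.
#
# print(process_pdf("pdfs/sample.pdf"))
# print(answer_query("What is the main topic of the document?"))
# ------------------------------------------------------------------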