# Hugging Face Space page header captured with this source (Space status: Sleeping).
import os
import shutil
import tempfile

import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy
from huggingface_hub import InferenceClient
from PIL import Image
# Directories: uploaded PDFs, and the image/table blocks extracted from them, live here.
PDF_DIR = "pdfs"
FIGURE_DIR = "figures"
for _storage_dir in (PDF_DIR, FIGURE_DIR):
    os.makedirs(_storage_dir, exist_ok=True)  # no-op when the directory already exists
del _storage_dir
# Embeddings and Model Setup
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# FAISS cannot build an index from an empty text list (it sizes the index from the first
# embedding), so the store is created by process_pdf on the first non-empty ingest.
vector_store = None
llm = HuggingFaceHub(repo_id="google/flan-t5-base", model_kwargs={"temperature": 0.5, "max_length": 512})
template = """
Use the following context to answer the question. If the answer is unknown, say so.
Context: {context}
Question: {question}
Answer (3 sentences max):
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
# Created together with vector_store; stays None until at least one PDF has been indexed.
qa_chain = None
# Hugging Face Inference API Client (for image captioning, etc.)
vision_model = InferenceClient("Salesforce/blip-image-captioning-base")


def _build_qa_chain(store):
    """Return a RetrievalQA chain that answers with ``llm``/``prompt`` over *store*."""
    return RetrievalQA.from_chain_type(
        llm=llm,
        retriever=store.as_retriever(),
        chain_type_kwargs={"prompt": prompt},
    )


def extract_image_text(file_path):
    """Caption the image stored at *file_path* using ``vision_model``.

    Args:
        file_path: Path to an image file on disk.

    Returns:
        The generated caption as a ``str``.
    """
    # image_to_text accepts raw bytes, a file path or a URL (not a PIL image), so the path
    # is sent as-is. Newer huggingface_hub releases return an object carrying a
    # ``generated_text`` attribute instead of a plain string; normalise to str so the
    # caption can be joined with the rest of the extracted text.
    result = vision_model.image_to_text(file_path)
    caption = getattr(result, "generated_text", result)
    return caption if isinstance(caption, str) else str(caption)


def _save_upload(file):
    """Copy an uploaded PDF (path or file-like with ``.name``) into PDF_DIR; return the copy's path."""
    if isinstance(file, (str, os.PathLike)):
        src_path = os.fspath(file)
    else:
        src_path = file.name
    # Join only the basename: an absolute temp path would make os.path.join discard PDF_DIR,
    # and opening that path for writing would truncate the uploaded file itself.
    dest_path = os.path.join(PDF_DIR, os.path.basename(src_path))
    if os.path.isfile(src_path):
        if os.path.abspath(src_path) != os.path.abspath(dest_path):
            shutil.copyfile(src_path, dest_path)
    elif hasattr(file, "read"):
        # File-like object whose name is not a path on disk: persist its contents.
        with open(dest_path, "wb") as f:
            f.write(file.read())
    else:
        raise FileNotFoundError(src_path)
    return dest_path


def process_pdf(file):
    """Extract text and figure captions from an uploaded PDF and index them for retrieval.

    Args:
        file: The uploaded PDF as passed by the Gradio ``File`` component — either a
            filesystem path or a file-like object whose ``.name`` is the upload path.
            ``None`` (button clicked with no upload) is reported back to the user.

    Returns:
        A status message for the UI.
    """
    global vector_store, qa_chain
    if file is None:
        return "Please upload a PDF first."

    pdf_path = _save_upload(file)
    pdf_name = os.path.basename(pdf_path)
    # FIGURE_DIR is shared across uploads, so listing it directly would re-caption and
    # re-index figures from earlier PDFs; give each run its own fresh subdirectory.
    figure_dir = tempfile.mkdtemp(prefix=f"{os.path.splitext(pdf_name)[0]}-", dir=FIGURE_DIR)

    elements = partition_pdf(
        pdf_path,
        strategy=PartitionStrategy.HI_RES,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_output_dir=figure_dir,
    )
    texts = [el.text for el in elements if el.category not in ("Image", "Table") and el.text]
    for fig_file in sorted(os.listdir(figure_dir)):
        texts.append(extract_image_text(os.path.join(figure_dir, fig_file)))
    full_text = "\n\n".join(texts)

    # Chunking
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    docs = splitter.split_text(full_text)
    if docs:
        if vector_store is None:
            vector_store = FAISS.from_texts(docs, embedding_model)
            qa_chain = _build_qa_chain(vector_store)
        else:
            # The chain's retriever wraps this same store, so new chunks are searchable at once.
            vector_store.add_texts(docs)
    return f"Processed {pdf_name} with {len(docs)} text chunks."


def answer_query(question):
    """Answer *question* from the indexed PDFs via ``qa_chain``.

    Returns a prompt message instead of calling the chain when the question is blank
    or no PDF has been indexed yet.
    """
    if not question or not question.strip():
        return "Please enter a question."
    if qa_chain is None:
        return "Please upload and process a PDF before asking a question."
    return qa_chain.run(question)
# Gradio UI: one row to upload and index a PDF, one row to ask questions about it.
with gr.Blocks() as demo:
    gr.Markdown("# 📄📷 Multimodal RAG with Hugging Face")
    with gr.Row():
        # NOTE(review): gradio 4.x removed type="file" (use "filepath"); confirm the pinned
        # gradio version before upgrading.
        file_input = gr.File(label="Upload PDF", type="file")
        upload_btn = gr.Button("Process PDF")
    status = gr.Textbox(label="Processing Status")
    with gr.Row():
        question = gr.Textbox(label="Ask a Question")
        ask_btn = gr.Button("Get Answer")
    answer_box = gr.Textbox(label="Answer")
    upload_btn.click(fn=process_pdf, inputs=file_input, outputs=status)
    ask_btn.click(fn=answer_query, inputs=question, outputs=answer_box)
    # Pressing Enter in the question box behaves like clicking "Get Answer".
    question.submit(fn=answer_query, inputs=question, outputs=answer_box)

# Only start the server when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()