import os
import shutil
import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy
from huggingface_hub import InferenceClient
# Directories
PDF_DIR = "pdfs"
FIGURE_DIR = "figures"
os.makedirs(PDF_DIR, exist_ok=True)
os.makedirs(FIGURE_DIR, exist_ok=True)
# Embeddings and Model Setup
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# FAISS cannot build an index from an empty text list, so seed it with a placeholder chunk.
vector_store = FAISS.from_texts([" "], embedding_model)
llm = HuggingFaceHub(repo_id="google/flan-t5-base", model_kwargs={"temperature": 0.5, "max_length": 512})
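# Note: newer LangChain releases deprecate HuggingFaceHub in favour of
# HuggingFaceEndpoint; an equivalent sketch, if migrating:
#   from langchain_huggingface import HuggingFaceEndpoint
#   llm = HuggingFaceEndpoint(repo_id="google/flan-t5-base", temperature=0.5, max_new_tokens=512)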
template = """
Use the following context to answer the question. If the answer is unknown, say so.
Context: {context}
Question: {question}
Answer (3 sentences max):
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=vector_store.as_retriever(),
    chain_type_kwargs={"prompt": prompt}
)
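# Note: this index lives in memory only, so uploaded content is lost on restart.
# If persistence is needed, LangChain's FAISS wrapper can round-trip to disk;
# a minimal sketch (the "faiss_index" folder name is an arbitrary choice):
#   vector_store.save_local("faiss_index")
#   vector_store = FAISS.load_local("faiss_index", embedding_model)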
# Hugging Face Inference API Client (for image captioning, etc.)
vision_model = InferenceClient("Salesforce/blip-image-captioning-base")

def extract_image_text(file_path):
    # image_to_text accepts a local file path directly; a PIL image object is not supported.
    caption = vision_model.image_to_text(file_path)
    # Newer huggingface_hub versions return an object with .generated_text; older ones return a string.
    return getattr(caption, "generated_text", caption)
def process_pdf(file_path):
    # Gradio's File component (type="filepath") hands us a temp-file path.
    pdf_path = os.path.join(PDF_DIR, os.path.basename(file_path))
    shutil.copy(file_path, pdf_path)

    # Snapshot existing figures so only the ones extracted from this PDF get captioned.
    existing_figures = set(os.listdir(FIGURE_DIR))

    elements = partition_pdf(
        pdf_path,
        strategy=PartitionStrategy.HI_RES,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_output_dir=FIGURE_DIR
    )
    texts = [el.text for el in elements if el.category not in ["Image", "Table"]]

    # Caption each newly extracted figure and fold the captions into the corpus.
    for fig_file in sorted(set(os.listdir(FIGURE_DIR)) - existing_figures):
        fig_path = os.path.join(FIGURE_DIR, fig_file)
        texts.append(extract_image_text(fig_path))

    full_text = "\n\n".join(texts)

    # Chunking
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    docs = splitter.split_text(full_text)
    vector_store.add_texts(docs)
    return f"Processed {os.path.basename(file_path)} with {len(docs)} text chunks."
def answer_query(question):
    return qa_chain.run(question)
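# Optional: RetrievalQA can also surface the retrieved chunks for display;
# a sketch, not wired into the UI below:
#   qa_with_sources = RetrievalQA.from_chain_type(
#       llm=llm,
#       retriever=vector_store.as_retriever(),
#       chain_type_kwargs={"prompt": prompt},
#       return_source_documents=True,
#   )
#   result = qa_with_sources({"query": question})
#   answer, sources = result["result"], result["source_documents"]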
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 📄📷 Multimodal RAG with Hugging Face")
    with gr.Row():
        file_input = gr.File(label="Upload PDF", type="filepath")
        upload_btn = gr.Button("Process PDF")
    status = gr.Textbox(label="Processing Status")
    with gr.Row():
        question = gr.Textbox(label="Ask a Question")
        ask_btn = gr.Button("Get Answer")
    answer_box = gr.Textbox(label="Answer")
    upload_btn.click(fn=process_pdf, inputs=file_input, outputs=status)
    ask_btn.click(fn=answer_query, inputs=question, outputs=answer_box)

demo.launch()