# app.py
import os
import shutil
from pathlib import Path
import gradio as gr
from PIL import Image
from huggingface_hub import InferenceClient
# ✅ Community imports
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceEndpoint
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy
# ----- Config & Folders -----
PDF_DIR = Path("pdfs")
FIG_DIR = Path("figures")
PDF_DIR.mkdir(exist_ok=True)
FIG_DIR.mkdir(exist_ok=True)
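# pdfs/ holds uploaded files; figures/ receives the image/table blocks that unstructured exports.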
# ----- Read your HF_TOKEN secret -----
hf_token = os.environ["HF_TOKEN"]
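# (A KeyError here means the HF_TOKEN secret is not configured for this Space.)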
# ----- Embeddings & LLM Setup -----
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
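# all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings; they feed the FAISS index built in process_pdf().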
# LLM via HF Inference API endpoint
llm = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/google/flan-t5-base",
    huggingfacehub_api_token=hf_token,
    temperature=0.5,
    max_new_tokens=512,  # HuggingFaceEndpoint expects max_new_tokens rather than max_length
)
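# flan-t5-base is a small seq2seq model; a larger instruction-tuned model can be swapped in via the same endpoint URL pattern.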
# Prompt
TEMPLATE = """
Use the following context to answer the question. If unknown, say so.
Context: {context}
Question: {question}
Answer (up to 3 sentences):
"""
prompt = PromptTemplate(template=TEMPLATE, input_variables=["context", "question"])
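# RetrievalQA's default "stuff" chain fills {context} with the retrieved chunks and {question} with the user's query.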
# Inference client for image captioning
vision_client = InferenceClient(
    model="Salesforce/blip-image-captioning-base",  # InferenceClient takes `model`, not `repo_id`
    token=hf_token,
)
# Globals (initialized once a PDF has been processed)
vector_store = None
qa_chain = None
def extract_image_caption(path: str) -> str:
    """Caption an extracted figure via the hosted BLIP model."""
    # image_to_text accepts a local file path directly (not a PIL Image);
    # recent huggingface_hub versions return an object with `.generated_text`,
    # older ones return the caption string itself.
    result = vision_client.image_to_text(path)
    return getattr(result, "generated_text", result)
def process_pdf(pdf_path: str) -> str:
    """Extract text and figures from an uploaded PDF and (re)build the QA chain."""
    global vector_store, qa_chain

    # gr.File(type="filepath") passes the upload as a temp-file path; keep a copy in pdfs/.
    out_path = PDF_DIR / Path(pdf_path).name
    shutil.copy(pdf_path, out_path)

    # Hi-res partitioning also exports image/table blocks as files into figures/.
    elems = partition_pdf(
        str(out_path),
        strategy=PartitionStrategy.HI_RES,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_output_dir=str(FIG_DIR),
    )
    texts = [el.text for el in elems if el.category not in ("Image", "Table")]

    # Caption every extracted figure and add the captions to the corpus.
    # Note: figures/ accumulates across uploads; clear it if you only want
    # captions from the current PDF.
    for img_file in FIG_DIR.iterdir():
        texts.append(extract_image_caption(str(img_file)))

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    docs = splitter.split_text("\n\n".join(texts))

    vector_store = FAISS.from_texts(docs, embedding_model)
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=vector_store.as_retriever(),
        chain_type_kwargs={"prompt": prompt},
    )
    return f"✅ Processed `{Path(pdf_path).name}` into {len(docs)} chunks."
def answer_query(question: str) -> str:
    """Answer a question against the indexed PDF, if one has been processed."""
    if qa_chain is None:
        return "❗ Please upload and process a PDF first."
    return qa_chain.run(question)
# ----- Gradio UI -----
with gr.Blocks() as demo:
    gr.Markdown("## 📄📷 Multimodal RAG - Hugging Face Spaces")

    with gr.Row():
        # type="filepath" hands process_pdf the uploaded file's path (type="file" is deprecated).
        pdf_in = gr.File(label="Upload PDF", type="filepath")
        btn_proc = gr.Button("Process PDF")
    status = gr.Textbox(label="Status")

    with gr.Row():
        q_in = gr.Textbox(label="Your Question")
        btn_ask = gr.Button("Ask")
    ans_out = gr.Textbox(label="Answer")

    btn_proc.click(fn=process_pdf, inputs=pdf_in, outputs=status)
    btn_ask.click(fn=answer_query, inputs=q_in, outputs=ans_out)
if __name__ == "__main__":
    demo.launch()
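# Likely Space dependencies (assumed, unpinned; adjust as needed):
#   gradio, langchain, langchain-community, sentence-transformers, faiss-cpu,
#   huggingface_hub, pillow, unstructured[pdf]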