Muzammil6376 committed on
Commit 40696fb · verified · 1 Parent(s): ced2810

Update app.py

Files changed (1)
  1. app.py +69 -68
app.py CHANGED
@@ -1,95 +1,96 @@
  import os
- import tempfile
-
  import gradio as gr
  from langchain.embeddings import HuggingFaceEmbeddings
  from langchain.vectorstores import FAISS
  from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.document_loaders import UnstructuredPDFLoader
- from langchain.chains import RetrievalQA
  from langchain.llms import HuggingFaceHub
  from PIL import Image
- from transformers import pipeline
-
- # Directories for temporary storage
- FIGURES_DIR = tempfile.mkdtemp(prefix="figures_")
-
- # Configure Hugging Face
- HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
-
- # Initialize embeddings and vector store
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
- vector_store = None
-
- # Initialize image captioning pipeline
- captioner = pipeline("image-to-text", model="Salesforce/blip2-flan-t5-xl", use_auth_token=HUGGINGFACEHUB_API_TOKEN)

- # Initialize LLM for QA
- llm = HuggingFaceHub(
-     repo_id="google/flan-t5-xxl",
-     model_kwargs={"temperature":0.0, "max_length":256},
-     huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN
  )

- # Helper functions

- def process_pdf(pdf_file):
-     # Load text content
-     loader = UnstructuredPDFLoader(pdf_file.name)
-     docs = loader.load()
-
-     # Basic text from PDF
-     raw_text = "\n".join([d.page_content for d in docs])
-
-     # Optionally extract images and caption them
-     # Here, we simply caption any embedded images
-     captions = []
-     # (In a real pipeline, extract and save images separately)
-     # For demo, we skip actual image files extraction
-
-     # Combine text and captions
-     combined = raw_text + "\n\n" + "\n".join(captions)
-     return combined

- def build_index(text):
-     global vector_store
-     # Split into chunks
-     splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-     chunks = splitter.split_text(text)
-
-     # Create or update FAISS index
-     vector_store = FAISS.from_texts(chunks, embeddings)

- def answer_query(query):
-     qa = RetrievalQA.from_chain_type(
-         llm=llm,
-         chain_type="stuff",
-         retriever=vector_store.as_retriever()
-     )
-     return qa.run(query)

  # Gradio UI
  with gr.Blocks() as demo:
-     gr.Markdown("# Multimodal RAG QA App")

      with gr.Row():
-         pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
-         question_input = gr.Textbox(label="Ask a question", placeholder="Enter your question here...")
-
-     output = gr.Textbox(label="Answer", interactive=False)

-     def on_submit(pdf, question):
-         if pdf is not None:
-             text = process_pdf(pdf)
-             build_index(text)
-         if not question:
-             return "Please enter a question."
-         return answer_query(question)

-     submit_btn = gr.Button("Get Answer")
-     submit_btn.click(on_submit, inputs=[pdf_input, question_input], outputs=output)

- if __name__ == "__main__":
-     demo.launch()
 
  import os
  import gradio as gr
  from langchain.embeddings import HuggingFaceEmbeddings
  from langchain.vectorstores import FAISS
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.llms import HuggingFaceHub
+ from langchain.chains import RetrievalQA
+ from langchain.prompts import PromptTemplate
+ from unstructured.partition.pdf import partition_pdf
+ from unstructured.partition.utils.constants import PartitionStrategy
+ from huggingface_hub import InferenceClient
  from PIL import Image
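+
+ # No token is hard-coded: HuggingFaceHub reads HUGGINGFACEHUB_API_TOKEN from the
+ # environment, and InferenceClient typically picks up HF_TOKEN or a cached login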
+
+ # Directories
+ PDF_DIR = "pdfs"
+ FIGURE_DIR = "figures"
+ os.makedirs(PDF_DIR, exist_ok=True)
+ os.makedirs(FIGURE_DIR, exist_ok=True)
+
+ # Embeddings and Model Setup
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+ # FAISS.from_texts() cannot build an index from an empty list (it has no text
+ # to infer the embedding dimension from), so seed it with a placeholder text
+ vector_store = FAISS.from_texts(["placeholder"], embedding_model)
+
+ llm = HuggingFaceHub(repo_id="google/flan-t5-base", model_kwargs={"temperature": 0.5, "max_length": 512})
+
+ template = """
+ Use the following context to answer the question. If the answer is unknown, say so.
+ Context: {context}
+ Question: {question}
+ Answer (3 sentences max):
+ """
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+
+ qa_chain = RetrievalQA.from_chain_type(
+     llm=llm,
+     retriever=vector_store.as_retriever(),
+     chain_type_kwargs={"prompt": prompt}
  )
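+
+ # The retriever keeps a reference to vector_store, so texts added later via
+ # vector_store.add_texts() are visible to qa_chain without rebuilding the chain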
 
+ # Hugging Face Inference API Client (for image captioning, etc.)
+ vision_model = InferenceClient("Salesforce/blip-image-captioning-base")
+
+ def extract_image_text(file_path):
+     # The Inference API accepts a file path, URL, or raw bytes rather than a
+     # PIL Image object, so send the image bytes directly
+     with open(file_path, "rb") as img_file:
+         result = vision_model.image_to_text(img_file.read())
+     # Older huggingface_hub releases return a plain string; newer ones return
+     # an object with a .generated_text attribute
+     return getattr(result, "generated_text", result)
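+
+ # Captioning could also run locally via transformers.pipeline("image-to-text",
+ # model="Salesforce/blip-image-captioning-base"), trading Inference API latency
+ # for local model memory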
+
+ def process_pdf(file_path):
+     # gr.File(type="filepath") hands us the upload's temp path as a string
+     pdf_path = os.path.join(PDF_DIR, os.path.basename(file_path))
+     with open(file_path, "rb") as src, open(pdf_path, "wb") as dst:
+         dst.write(src.read())
+
+     elements = partition_pdf(
+         pdf_path,
+         strategy=PartitionStrategy.HI_RES,
+         extract_image_block_types=["Image", "Table"],
+         extract_image_block_output_dir=FIGURE_DIR
+     )
+
+     texts = [el.text for el in elements if el.category not in ["Image", "Table"]]
+
+     # FIGURE_DIR accumulates figures across uploads, so images from earlier
+     # PDFs are re-captioned (and re-indexed) here as well
+     for fig_file in os.listdir(FIGURE_DIR):
+         fig_path = os.path.join(FIGURE_DIR, fig_file)
+         caption = extract_image_text(fig_path)
+         texts.append(caption)
+
+     full_text = "\n\n".join(texts)
+
+     # Chunking
+     splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+     docs = splitter.split_text(full_text)
+     vector_store.add_texts(docs)
+
+     return f"Processed {os.path.basename(file_path)} with {len(docs)} text chunks."
+
+ def answer_query(question):
+     return qa_chain.run(question)
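+
+ # qa_chain.run() matches the legacy langchain 0.0.x API imported above; newer
+ # LangChain releases deprecate run() in favor of invoke()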

  # Gradio UI
  with gr.Blocks() as demo:
+     gr.Markdown("# 📄📷 Multimodal RAG with Hugging Face")

      with gr.Row():
+         file_input = gr.File(label="Upload PDF", file_types=[".pdf"], type="filepath")
+         upload_btn = gr.Button("Process PDF")
+         status = gr.Textbox(label="Processing Status")

+     with gr.Row():
+         question = gr.Textbox(label="Ask a Question")
+         ask_btn = gr.Button("Get Answer")
+         answer_box = gr.Textbox(label="Answer")

+     upload_btn.click(fn=process_pdf, inputs=file_input, outputs=status)
+     ask_btn.click(fn=answer_query, inputs=question, outputs=answer_box)

+ demo.launch()