Muzammil6376 committed on
Commit 3fdd093 · verified · 1 Parent(s): 40696fb

Update app.py

Files changed (1)
  1. app.py +83 -58
app.py CHANGED
```diff
@@ -1,96 +1,121 @@
+# app.py
 import os
+from pathlib import Path
+
 import gradio as gr
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
+from PIL import Image
+from huggingface_hub import InferenceClient
+
+# ✅ Use community packages to avoid deprecation warnings
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain_community.llms import HuggingFaceHub
+
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.llms import HuggingFaceHub
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
+
 from unstructured.partition.pdf import partition_pdf
 from unstructured.partition.utils.constants import PartitionStrategy
-from huggingface_hub import InferenceClient
-from PIL import Image
 
-# Directories
-PDF_DIR = "pdfs"
-FIGURE_DIR = "figures"
-os.makedirs(PDF_DIR, exist_ok=True)
-os.makedirs(FIGURE_DIR, exist_ok=True)
+# ————— Config & Folders —————
+PDF_DIR = Path("pdfs")
+FIG_DIR = Path("figures")
+PDF_DIR.mkdir(exist_ok=True)
+FIG_DIR.mkdir(exist_ok=True)
 
-# Embeddings and Model Setup
+# ————— Embeddings & LLM Setup —————
 embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
-vector_store = FAISS.from_texts([], embedding_model)
 
-llm = HuggingFaceHub(repo_id="google/flan-t5-base", model_kwargs={"temperature": 0.5, "max_length": 512})
+# LLM via Hugging Face Inference API
+llm = HuggingFaceHub(
+    repo_id="google/flan-t5-base",
+    model_kwargs={"temperature": 0.5, "max_length": 512}
+)
 
-template = """
-Use the following context to answer the question. If the answer is unknown, say so.
+# Prompt
+TEMPLATE = """
+Use the following context to answer the question. If unknown, say so.
 Context: {context}
 Question: {question}
-Answer (3 sentences max):
+Answer (up to 3 sentences):
 """
-prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+prompt = PromptTemplate(template=TEMPLATE, input_variables=["context", "question"])
+
+# Inference client for image captioning
+vision_client = InferenceClient("Salesforce/blip-image-captioning-base")
+
+# Globals (will set after processing)
+vector_store = None
+qa_chain = None
 
-qa_chain = RetrievalQA.from_chain_type(
-    llm=llm,
-    retriever=vector_store.as_retriever(),
-    chain_type_kwargs={"prompt": prompt}
-)
 
-# Hugging Face Inference API Client (for image captioning, etc.)
-vision_model = InferenceClient("Salesforce/blip-image-captioning-base")
+def extract_image_caption(path: str) -> str:
+    """Return an autogenerated caption for an image file."""
+    with Image.open(path) as img:
+        return vision_client.image_to_text(img)
 
-def extract_image_text(file_path):
-    with Image.open(file_path) as img:
-        caption = vision_model.image_to_text(img)
-    return caption
 
-def process_pdf(file):
-    pdf_path = os.path.join(PDF_DIR, file.name)
-    with open(pdf_path, "wb") as f:
-        f.write(file.read())
+def process_pdf(pdf_file) -> str:
+    """Save, parse, chunk, embed & index a PDF (text + images)."""
+    global vector_store, qa_chain
 
-    elements = partition_pdf(
-        pdf_path,
+    # 1️⃣ Save PDF
+    out_path = PDF_DIR / pdf_file.name
+    with open(out_path, "wb") as f:
+        f.write(pdf_file.read())
+
+    # 2️⃣ Partition into text + image blocks
+    elems = partition_pdf(
+        str(out_path),
         strategy=PartitionStrategy.HI_RES,
         extract_image_block_types=["Image", "Table"],
-        extract_image_block_output_dir=FIGURE_DIR
+        extract_image_block_output_dir=str(FIG_DIR),
     )
 
-    texts = [el.text for el in elements if el.category not in ["Image", "Table"]]
-
-    for fig_file in os.listdir(FIGURE_DIR):
-        fig_path = os.path.join(FIGURE_DIR, fig_file)
-        caption = extract_image_text(fig_path)
-        texts.append(caption)
+    # 3️⃣ Collect text
+    texts = [el.text for el in elems if el.category not in ("Image", "Table")]
 
-    full_text = "\n\n".join(texts)
+    # 4️⃣ Caption each image
+    for img_file in FIG_DIR.iterdir():
+        texts.append(extract_image_caption(str(img_file)))
 
-    # Chunking
+    # 5️⃣ Split & index
     splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-    docs = splitter.split_text(full_text)
-    vector_store.add_texts(docs)
+    docs = splitter.split_text("\n\n".join(texts))
 
-    return f"Processed {file.name} with {len(docs)} text chunks."
+    vector_store = FAISS.from_texts(docs, embedding_model)
+    qa_chain = RetrievalQA.from_chain_type(
+        llm=llm,
+        retriever=vector_store.as_retriever(),
+        chain_type_kwargs={"prompt": prompt},
+    )
+
+    return f"✅ Processed `{pdf_file.name}` into {len(docs)} chunks."
 
-def answer_query(question):
+
+def answer_query(question: str) -> str:
+    if qa_chain is None:
+        return "❗ Please upload and process a PDF first."
     return qa_chain.run(question)
 
-# Gradio UI
+
+# ————— Gradio UI —————
 with gr.Blocks() as demo:
-    gr.Markdown("# 📄📷 Multimodal RAG with Hugging Face")
+    gr.Markdown("## 📄📷 Multimodal RAG — Hugging Face Spaces")
 
     with gr.Row():
-        file_input = gr.File(label="Upload PDF", type="file")
-        upload_btn = gr.Button("Process PDF")
-        status = gr.Textbox(label="Processing Status")
+        pdf_in = gr.File(label="Upload PDF", type="file")
+        btn_proc = gr.Button("Process PDF")
+        status = gr.Textbox(label="Status")
 
     with gr.Row():
-        question = gr.Textbox(label="Ask a Question")
-        ask_btn = gr.Button("Get Answer")
-        answer_box = gr.Textbox(label="Answer")
+        q_in = gr.Textbox(label="Your Question")
+        btn_ask = gr.Button("Ask")
+        ans_out = gr.Textbox(label="Answer")
 
-    upload_btn.click(fn=process_pdf, inputs=file_input, outputs=status)
-    ask_btn.click(fn=answer_query, inputs=question, outputs=answer_box)
+    btn_proc.click(fn=process_pdf, inputs=pdf_in, outputs=status)
+    btn_ask.click(fn=answer_query, inputs=q_in, outputs=ans_out)
 
-demo.launch()
+
+if __name__ == "__main__":
+    demo.launch()
```
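A few review notes on the updated file. The core fix is sound: the old version called `FAISS.from_texts([], embedding_model)` at import time, which raises on an empty text list, while the new version defers index construction until the first PDF is processed and guards `answer_query` behind a `qa_chain is None` check.

One carry-over worth flagging: `gr.File(type="file")` was only accepted by Gradio 3.x. In Gradio 4.x the `type` parameter is limited to `"filepath"` (the default) and `"binary"`, and the handler then receives a temp-file path rather than a file object, so `pdf_file.name` and `pdf_file.read()` would fail. A minimal sketch of the upload path under that assumption:

```python
# Sketch, assuming Gradio 4.x: gr.File accepts only type="filepath"
# or type="binary", so the type="file" used above is rejected.
import shutil
from pathlib import Path

import gradio as gr

PDF_DIR = Path("pdfs")
PDF_DIR.mkdir(exist_ok=True)

def process_pdf(pdf_path: str) -> str:
    # The handler now receives a temp-file *path*, not a file object,
    # so copy it instead of calling .read() or .name on it.
    out_path = PDF_DIR / Path(pdf_path).name
    shutil.copy(pdf_path, out_path)
    return f"Saved {out_path}"

with gr.Blocks() as demo:
    pdf_in = gr.File(label="Upload PDF", type="filepath")
    status = gr.Textbox(label="Status")
    pdf_in.upload(fn=process_pdf, inputs=pdf_in, outputs=status)
```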
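Relatedly, `extract_image_caption` hands a PIL `Image` object to `InferenceClient.image_to_text`. As far as I can tell, the client expects a local path, raw bytes, or a URL rather than a PIL object, and recent `huggingface_hub` releases return an `ImageToTextOutput` object instead of a plain string. A hedged version of the helper:

```python
# Sketch, assuming InferenceClient.image_to_text() takes a path, bytes,
# or URL (not a PIL.Image) and may return either a plain str (older
# huggingface_hub releases) or an object exposing .generated_text.
from huggingface_hub import InferenceClient

vision_client = InferenceClient("Salesforce/blip-image-captioning-base")

def extract_image_caption(path: str) -> str:
    """Return an autogenerated caption for an image file."""
    result = vision_client.image_to_text(path)  # the client reads and uploads the bytes
    return getattr(result, "generated_text", result)
```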
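The ✅ comment about avoiding deprecation warnings may not fully hold either: `langchain_community.llms.HuggingFaceHub` is itself deprecated upstream in favor of `HuggingFaceEndpoint`. If the `langchain-huggingface` package is available in the Space, the LLM setup could be swapped like this, with `max_new_tokens` standing in for the `max_length` model kwarg:

```python
# Sketch, assuming langchain-huggingface is installed and an HF token
# is configured in the Space (HuggingFaceEndpoint calls the hosted API).
from langchain_huggingface import HuggingFaceEndpoint

llm = HuggingFaceEndpoint(
    repo_id="google/flan-t5-base",
    temperature=0.5,
    max_new_tokens=512,  # rough stand-in for the max_length kwarg above
)
```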
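Finally, a behavioral wrinkle in `process_pdf`: step 4 iterates over all of `FIG_DIR`, so figures extracted from earlier uploads are re-captioned and indexed along with each new PDF. A hypothetical fix (the `fig_subdir` and helper names are mine, not from the commit) is to give each PDF its own figure folder:

```python
# Hypothetical tweak: per-PDF figure folders, so captions from earlier
# uploads don't leak into later indexes (replaces steps 2 and 4 above).
from pathlib import Path

from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy

FIG_DIR = Path("figures")

def partition_with_isolated_figures(pdf_path: Path):
    """Partition one PDF, writing its figures to a dedicated subfolder."""
    fig_subdir = FIG_DIR / pdf_path.stem
    fig_subdir.mkdir(parents=True, exist_ok=True)
    elems = partition_pdf(
        str(pdf_path),
        strategy=PartitionStrategy.HI_RES,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_output_dir=str(fig_subdir),
    )
    return elems, fig_subdir  # caption only fig_subdir's contents downstream
```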