import base64
import os
import subprocess
from io import BytesIO

import fitz  # PyMuPDF
import requests
import streamlit as st
from PIL import Image
from pptx import Presentation

from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.nvidia import NVIDIAEmbedding
from llama_index.llms.nvidia import NVIDIA
from llama_index.vector_stores.milvus import MilvusVectorStore
def set_environment_variables():
    """Set necessary environment variables."""
    # Set your NVIDIA API key here, or export NVIDIA_API_KEY before launching;
    # never commit a real key to source control.
    os.environ["NVIDIA_API_KEY"] = "nvapi-..."  # replace with your key
def get_b64_image_from_content(image_content):
    """Convert image content to base64 encoded string."""
    img = Image.open(BytesIO(image_content))
    if img.mode != 'RGB':
        img = img.convert('RGB')
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
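
# Usage sketch (hedged; "photo.png" is a hypothetical local file):
#   with open("photo.png", "rb") as fh:
#       b64 = get_b64_image_from_content(fh.read())
#   print(b64[:40])  # start of the base64 JPEG payload
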
def is_graph(image_content):
    """Determine if an image is a graph, plot, chart, or table."""
    res = describe_image(image_content)
    return any(keyword in res.lower() for keyword in ["graph", "plot", "chart", "table"])
def process_graph(image_content):
    """Process a graph image and generate a description."""
    deplot_description = process_graph_deplot(image_content)
    llm = NVIDIA(model_name="meta/llama-3.1-70b-instruct")  # renamed from `mixtral`: the model is Llama 3.1, not Mixtral
    response = llm.complete(
        "Your responsibility is to explain charts. You are an expert in describing "
        "the responses of linearized tables into plain English text for LLMs to use. "
        "Explain the following linearized table. " + deplot_description
    )
    return response.text
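
# Pipeline sketch (hedged; "chart.png" is a hypothetical file, and both calls
# bill against your NVIDIA API key):
#   with open("chart.png", "rb") as fh:
#       content = fh.read()
#   if is_graph(content):                # VLM caption mentions graph/plot/chart/table?
#       print(process_graph(content))    # DePlot linearizes the chart, then the LLM explains it
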
def describe_image(image_content):
    """Generate a description of an image using NVIDIA's NeVA API."""
    image_b64 = get_b64_image_from_content(image_content)
    invoke_url = "https://ai.api.nvidia.com/v1/vlm/nvidia/neva-22b"
    api_key = os.getenv("NVIDIA_API_KEY")
    if not api_key:
        raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json"
    }
    payload = {
        "messages": [
            {
                "role": "user",
                "content": f'Describe what you see in this image. <img src="data:image/png;base64,{image_b64}" />'
            }
        ],
        "max_tokens": 1024,
        "temperature": 0.20,
        "top_p": 0.70,
        "seed": 0,
        "stream": False
    }
    response = requests.post(invoke_url, headers=headers, json=payload)
    response.raise_for_status()  # fail loudly on HTTP errors instead of on a missing JSON key
    return response.json()["choices"][0]["message"]["content"]
def process_graph_deplot(image_content):
    """Process a graph image using NVIDIA's hosted Google DePlot API."""
    invoke_url = "https://ai.api.nvidia.com/v1/vlm/google/deplot"
    image_b64 = get_b64_image_from_content(image_content)
    api_key = os.getenv("NVIDIA_API_KEY")
    if not api_key:
        raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json"
    }
    payload = {
        "messages": [
            {
                "role": "user",
                "content": f'Generate underlying data table of the figure below: <img src="data:image/png;base64,{image_b64}" />'
            }
        ],
        "max_tokens": 1024,
        "temperature": 0.20,
        "top_p": 0.20,
        "stream": False
    }
    response = requests.post(invoke_url, headers=headers, json=payload)
    response.raise_for_status()  # fail loudly on HTTP errors
    return response.json()["choices"][0]["message"]["content"]
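
# DePlot returns the chart as a linearized text table; a hedged sketch of the
# output shape (exact separators may vary):
#   "TITLE | Revenue by year <0x0A> Year | Revenue <0x0A> 2021 | 10 <0x0A> 2022 | 15"
# process_graph() then feeds this string to the LLM to turn it into prose.
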
def extract_text_around_item(text_blocks, bbox, page_height, threshold_percentage=0.1):
    """Extract the nearest text above and below a given bounding box on a page."""
    before_text, after_text = "", ""
    vertical_threshold_distance = page_height * threshold_percentage
    horizontal_threshold_distance = bbox.width * threshold_percentage
    for block in text_blocks:
        block_bbox = fitz.Rect(block[:4])
        vertical_distance = min(abs(block_bbox.y1 - bbox.y0), abs(block_bbox.y0 - bbox.y1))
        horizontal_overlap = max(0, min(block_bbox.x1, bbox.x1) - max(block_bbox.x0, bbox.x0))
        # Note: horizontal_overlap is clamped to >= 0, so comparing it against a negative
        # threshold is always true; in practice only the vertical proximity check filters blocks.
        if vertical_distance <= vertical_threshold_distance and horizontal_overlap >= -horizontal_threshold_distance:
            if block_bbox.y1 < bbox.y0 and not before_text:
                before_text = block[4]
            elif block_bbox.y0 > bbox.y1 and not after_text:
                after_text = block[4]
                break
    return before_text, after_text
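
# PyMuPDF "blocks" are tuples (x0, y0, x1, y1, text, block_no, block_type).
# Usage sketch (hedged; "doc.pdf" is a hypothetical file and the Rect is arbitrary):
#   page = fitz.open("doc.pdf")[0]
#   blocks = page.get_text("blocks", sort=True)
#   before, after = extract_text_around_item(blocks, fitz.Rect(100, 200, 300, 260), page.rect.height)
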
def process_text_blocks(text_blocks, char_count_threshold=500):
    """Group text blocks based on a character count threshold."""
    current_group = []
    grouped_blocks = []
    current_char_count = 0
    for block in text_blocks:
        if block[-1] == 0:  # Check if the block is of text type
            block_text = block[4]
            block_char_count = len(block_text)
            if current_char_count + block_char_count <= char_count_threshold:
                current_group.append(block)
                current_char_count += block_char_count
            else:
                if current_group:
                    grouped_content = "\n".join([b[4] for b in current_group])
                    grouped_blocks.append((current_group[0], grouped_content))
                current_group = [block]
                current_char_count = block_char_count
    # Append the last group
    if current_group:
        grouped_content = "\n".join([b[4] for b in current_group])
        grouped_blocks.append((current_group[0], grouped_content))
    return grouped_blocks
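
# Grouping sketch (hedged; `blocks` as returned by page.get_text("blocks", sort=True)):
#   grouped = process_text_blocks(blocks)
#   for heading_block, content in grouped:
#       print(heading_block[:4], len(content))  # bbox of the group's first block, merged text length
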
def save_uploaded_file(uploaded_file):
    """Save an uploaded file to a temporary directory."""
    temp_dir = os.path.join(os.getcwd(), "vectorstore", "ppt_references", "tmp")
    os.makedirs(temp_dir, exist_ok=True)
    temp_file_path = os.path.join(temp_dir, uploaded_file.name)
    with open(temp_file_path, "wb") as temp_file:
        temp_file.write(uploaded_file.read())
    return temp_file_path
# Second code file: document loading and parsing
def get_pdf_documents(pdf_file):
    """Process a PDF file and extract text, tables, and images."""
    all_pdf_documents = []
    ongoing_tables = {}
    try:
        f = fitz.open(stream=pdf_file.read(), filetype="pdf")
    except Exception as e:
        print(f"Error opening or processing the PDF file: {e}")
        return []
    for i in range(len(f)):
        page = f[i]
        # Keep text blocks only, skipping the top and bottom 10% of the page (headers/footers)
        text_blocks = [block for block in page.get_text("blocks", sort=True)
                       if block[-1] == 0 and not (block[1] < page.rect.height * 0.1 or block[3] > page.rect.height * 0.9)]
        grouped_text_blocks = process_text_blocks(text_blocks)
        table_docs, table_bboxes, ongoing_tables = parse_all_tables(pdf_file.name, page, i, text_blocks, ongoing_tables)
        all_pdf_documents.extend(table_docs)
        image_docs = parse_all_images(pdf_file.name, page, i, text_blocks)
        all_pdf_documents.extend(image_docs)
        for text_block_ctr, (heading_block, content) in enumerate(grouped_text_blocks, 1):
            heading_bbox = fitz.Rect(heading_block[:4])
            # Skip text that overlaps a detected table; it is already captured in the table document
            if not any(heading_bbox.intersects(table_bbox) for table_bbox in table_bboxes):
                bbox = {"x1": heading_block[0], "y1": heading_block[1], "x2": heading_block[2], "y2": heading_block[3]}  # fixed key: was "x3"
                text_doc = Document(
                    text=f"{heading_block[4]}\n{content}",
                    metadata={
                        **bbox,
                        "type": "text",
                        "page_num": i,
                        "source": f"{pdf_file.name[:-4]}-page{i}-block{text_block_ctr}"
                    },
                    id_=f"{pdf_file.name[:-4]}-page{i}-block{text_block_ctr}"
                )
                all_pdf_documents.append(text_doc)
    f.close()
    return all_pdf_documents
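
# Usage sketch (hedged; "report.pdf" is a hypothetical file; any object with
# .read() and .name works, which covers both open() handles and Streamlit uploads):
#   with open("report.pdf", "rb") as fh:
#       docs = get_pdf_documents(fh)
#   print(len(docs), [d.metadata["type"] for d in docs[:5]])
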
def parse_all_tables(filename, page, pagenum, text_blocks, ongoing_tables):
    """Extract tables from a PDF page."""
    table_docs = []
    table_bboxes = []
    try:
        tables = page.find_tables(horizontal_strategy="lines_strict", vertical_strategy="lines_strict")
        for tab in tables:
            if not tab.header.external:
                pandas_df = tab.to_pandas()
                tablerefdir = os.path.join(os.getcwd(), "vectorstore/table_references")
                os.makedirs(tablerefdir, exist_ok=True)
                df_xlsx_path = os.path.join(tablerefdir, f"table{len(table_docs)+1}-page{pagenum}.xlsx")
                pandas_df.to_excel(df_xlsx_path)
                bbox = fitz.Rect(tab.bbox)
                table_bboxes.append(bbox)
                before_text, after_text = extract_text_around_item(text_blocks, bbox, page.rect.height)
                # Render the table region to an image and describe it via the graph pipeline
                table_img = page.get_pixmap(clip=bbox)
                table_img_path = os.path.join(tablerefdir, f"table{len(table_docs)+1}-page{pagenum}.jpg")
                table_img.save(table_img_path)
                description = process_graph(table_img.tobytes())
                caption = before_text.replace("\n", " ") + description + after_text.replace("\n", " ")
                if before_text == "" and after_text == "":
                    caption = " ".join(tab.header.names)
                table_metadata = {
                    "source": f"{filename[:-4]}-page{pagenum}-table{len(table_docs)+1}",
                    "dataframe": df_xlsx_path,
                    "image": table_img_path,
                    "caption": caption,
                    "type": "table",
                    "page_num": pagenum
                }
                all_cols = ", ".join(list(pandas_df.columns.values))
                doc = Document(text=f"This is a table with the caption: {caption}\nThe columns are {all_cols}", metadata=table_metadata)
                table_docs.append(doc)
    except Exception as e:
        print(f"Error during table extraction: {e}")
    return table_docs, table_bboxes, ongoing_tables
def parse_all_images(filename, page, pagenum, text_blocks):
    """Extract images from a PDF page."""
    image_docs = []
    image_info_list = page.get_image_info(xrefs=True)
    page_rect = page.rect
    for image_info in image_info_list:
        xref = image_info['xref']
        if xref == 0:
            continue
        img_bbox = fitz.Rect(image_info['bbox'])
        # Skip decorative images smaller than 1/20 of the page in either dimension
        if img_bbox.width < page_rect.width / 20 or img_bbox.height < page_rect.height / 20:
            continue
        extracted_image = page.parent.extract_image(xref)
        image_data = extracted_image["image"]
        imgrefpath = os.path.join(os.getcwd(), "vectorstore/image_references")
        os.makedirs(imgrefpath, exist_ok=True)
        image_path = os.path.join(imgrefpath, f"image{xref}-page{pagenum}.png")
        with open(image_path, "wb") as img_file:
            img_file.write(image_data)
        before_text, after_text = extract_text_around_item(text_blocks, img_bbox, page.rect.height)
        if before_text == "" and after_text == "":
            continue  # no surrounding text to anchor a caption, so skip the image
        image_description = " "
        if is_graph(image_data):
            image_description = process_graph(image_data)
        caption = before_text.replace("\n", " ") + image_description + after_text.replace("\n", " ")
        image_metadata = {
            "source": f"{filename[:-4]}-page{pagenum}-image{xref}",
            "image": image_path,
            "caption": caption,
            "type": "image",
            "page_num": pagenum
        }
        image_docs.append(Document(text="This is an image with the caption: " + caption, metadata=image_metadata))
    return image_docs
def process_ppt_file(ppt_path):
    """Process a PowerPoint file."""
    pdf_path = convert_ppt_to_pdf(ppt_path)
    images_data = convert_pdf_to_images(pdf_path)
    slide_texts = extract_text_and_notes_from_ppt(ppt_path)
    processed_data = []
    for (image_path, page_num), (slide_text, notes) in zip(images_data, slide_texts):
        if notes:
            notes = "\n\nThe speaker notes for this slide are: " + notes
        with open(image_path, 'rb') as image_file:
            image_content = image_file.read()
        image_description = " "
        if is_graph(image_content):
            image_description = process_graph(image_content)
        image_metadata = {
            "source": f"{os.path.basename(ppt_path)}",
            "image": image_path,
            "caption": slide_text + image_description + notes,
            "type": "image",
            "page_num": page_num
        }
        processed_data.append(Document(text="This is a slide with the text: " + slide_text + image_description, metadata=image_metadata))
    return processed_data
def convert_ppt_to_pdf(ppt_path):
    """Convert a PowerPoint file to PDF using LibreOffice."""
    base_name = os.path.basename(ppt_path)
    ppt_name_without_ext = os.path.splitext(base_name)[0].replace(' ', '_')
    new_dir_path = os.path.abspath("vectorstore/ppt_references")
    os.makedirs(new_dir_path, exist_ok=True)
    # Caveat: LibreOffice names the output after the original basename, so if the
    # source filename contains spaces the PDF will not match this underscored path.
    pdf_path = os.path.join(new_dir_path, f"{ppt_name_without_ext}.pdf")
    command = ['libreoffice', '--headless', '--convert-to', 'pdf', '--outdir', new_dir_path, ppt_path]
    subprocess.run(command, check=True)
    return pdf_path
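
# Note: this shells out to LibreOffice, so the `libreoffice` binary must be on
# PATH (e.g. `apt-get install libreoffice` on Debian/Ubuntu). Usage sketch
# (hedged; "talk.pptx" is a hypothetical file):
#   pdf_path = convert_ppt_to_pdf("talk.pptx")
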
def convert_pdf_to_images(pdf_path):
    """Convert a PDF file to a series of images using PyMuPDF."""
    doc = fitz.open(pdf_path)
    base_name = os.path.basename(pdf_path)
    pdf_name_without_ext = os.path.splitext(base_name)[0].replace(' ', '_')
    new_dir_path = os.path.join(os.getcwd(), "vectorstore/ppt_references")
    os.makedirs(new_dir_path, exist_ok=True)
    image_paths = []
    for page_num in range(len(doc)):
        page = doc.load_page(page_num)
        pix = page.get_pixmap()
        output_image_path = os.path.join(new_dir_path, f"{pdf_name_without_ext}_{page_num:04d}.png")
        pix.save(output_image_path)
        image_paths.append((output_image_path, page_num))
    doc.close()
    return image_paths
def extract_text_and_notes_from_ppt(ppt_path):
    """Extract text and notes from a PowerPoint file."""
    prs = Presentation(ppt_path)
    text_and_notes = []
    for slide in prs.slides:
        slide_text = ' '.join([shape.text for shape in slide.shapes if hasattr(shape, "text")])
        try:
            notes = slide.notes_slide.notes_text_frame.text if slide.notes_slide else ''
        except Exception:
            notes = ''
        text_and_notes.append((slide_text, notes))
    return text_and_notes
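
# Usage sketch (hedged; "talk.pptx" is a hypothetical file):
#   for slide_text, notes in extract_text_and_notes_from_ppt("talk.pptx"):
#       print(slide_text[:60], "| has notes:", bool(notes))
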
def load_multimodal_data(files):
    """Load and process multiple file types."""
    documents = []
    for file in files:
        file_extension = os.path.splitext(file.name.lower())[1]
        if file_extension in ('.png', '.jpg', '.jpeg'):
            image_content = file.read()
            image_text = describe_image(image_content)
            doc = Document(text=image_text, metadata={"source": file.name, "type": "image"})
            documents.append(doc)
        elif file_extension == '.pdf':
            try:
                pdf_documents = get_pdf_documents(file)
                documents.extend(pdf_documents)
            except Exception as e:
                print(f"Error processing PDF {file.name}: {e}")
        elif file_extension in ('.ppt', '.pptx'):
            try:
                ppt_documents = process_ppt_file(save_uploaded_file(file))
                documents.extend(ppt_documents)
            except Exception as e:
                print(f"Error processing PPT {file.name}: {e}")
        else:
            text = file.read().decode("utf-8")
            doc = Document(text=text, metadata={"source": file.name, "type": "text"})
            documents.append(doc)
    return documents
def load_data_from_directory(directory):
    """Load and process multiple file types from a directory."""
    documents = []
    for filename in os.listdir(directory):
        filepath = os.path.join(directory, filename)
        file_extension = os.path.splitext(filename.lower())[1]
        print(f"Processing {filename}")
        if file_extension in ('.png', '.jpg', '.jpeg'):
            with open(filepath, "rb") as image_file:
                image_content = image_file.read()
            image_text = describe_image(image_content)
            doc = Document(text=image_text, metadata={"source": filename, "type": "image"})
            documents.append(doc)
        elif file_extension == '.pdf':
            with open(filepath, "rb") as pdf_file:
                try:
                    pdf_documents = get_pdf_documents(pdf_file)
                    documents.extend(pdf_documents)
                except Exception as e:
                    print(f"Error processing PDF {filename}: {e}")
        elif file_extension in ('.ppt', '.pptx'):
            try:
                ppt_documents = process_ppt_file(filepath)
                documents.extend(ppt_documents)
            except Exception as e:
                print(f"Error processing PPT {filename}: {e}")
        else:
            with open(filepath, "r", encoding="utf-8") as text_file:
                text = text_file.read()
            doc = Document(text=text, metadata={"source": filename, "type": "text"})
            documents.append(doc)
    return documents
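
# Usage sketch (hedged; "./data" is a hypothetical directory mixing PDFs,
# slide decks, images, and plain-text files):
#   docs = load_data_from_directory("./data")
#   print(f"Loaded {len(docs)} documents")
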
# Third code file: the Streamlit application
# Set up the page configuration
st.set_page_config(layout="wide")

# Initialize settings
def initialize_settings():
    Settings.embed_model = NVIDIAEmbedding(model="nvidia/nv-embedqa-e5-v5", truncate="END")
    Settings.llm = NVIDIA(model="meta/llama-3.1-70b-instruct")
    Settings.text_splitter = SentenceSplitter(chunk_size=600)
# Create index from documents
def create_index(documents):
    vector_store = MilvusVectorStore(
        host="127.0.0.1",
        port=19530,
        dim=1024
    )
    # vector_store = MilvusVectorStore(uri="./milvus_demo.db", dim=1024, overwrite=True)  # for a CPU-only vector store
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    return VectorStoreIndex.from_documents(documents, storage_context=storage_context)
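
# Usage sketch (hedged; assumes a Milvus server is reachable at 127.0.0.1:19530
# and initialize_settings() has run, so the embedding dimension matches 1024):
#   index = create_index(docs)
#   print(index.as_query_engine(similarity_top_k=5).query("Summarize the chart on page 2"))
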
# Main function to run the Streamlit app
def main():
    set_environment_variables()
    initialize_settings()
    col1, col2 = st.columns([1, 2])
    with col1:
        st.title("Multimodal RAG")
        input_method = st.radio("Choose input method:", ("Upload Files", "Enter Directory Path"))
        if input_method == "Upload Files":
            uploaded_files = st.file_uploader("Drag and drop files here", accept_multiple_files=True)
            if uploaded_files and st.button("Process Files"):
                with st.spinner("Processing files..."):
                    documents = load_multimodal_data(uploaded_files)
                    st.session_state['index'] = create_index(documents)
                    st.session_state['history'] = []
                    st.success("Files processed and index created!")
        else:
            directory_path = st.text_input("Enter directory path:")
            if directory_path and st.button("Process Directory"):
                if os.path.isdir(directory_path):
                    with st.spinner("Processing directory..."):
                        documents = load_data_from_directory(directory_path)
                        st.session_state['index'] = create_index(documents)
                        st.session_state['history'] = []
                        st.success("Directory processed and index created!")
                else:
                    st.error("Invalid directory path. Please enter a valid path.")
    with col2:
        if 'index' in st.session_state:
            st.title("Chat")
            if 'history' not in st.session_state:
                st.session_state['history'] = []
            query_engine = st.session_state['index'].as_query_engine(similarity_top_k=5, streaming=True)
            user_input = st.chat_input("Enter your query:")
            # Display chat messages
            chat_container = st.container()
            with chat_container:
                for message in st.session_state['history']:
                    with st.chat_message(message["role"]):
                        st.markdown(message["content"])
            if user_input:
                with st.chat_message("user"):
                    st.markdown(user_input)
                st.session_state['history'].append({"role": "user", "content": user_input})
                with st.chat_message("assistant"):
                    message_placeholder = st.empty()
                    full_response = ""
                    response = query_engine.query(user_input)
                    for token in response.response_gen:
                        full_response += token
                        message_placeholder.markdown(full_response + "▌")
                    message_placeholder.markdown(full_response)
                st.session_state['history'].append({"role": "assistant", "content": full_response})
            # Add a clear button
            if st.button("Clear Chat"):
                st.session_state['history'] = []
                st.rerun()

if __name__ == "__main__":
    main()