"""Streamlit multimodal RAG app.

Ingests images, PDFs, PowerPoint decks and plain-text files, turns them into
LlamaIndex ``Document`` objects (using NVIDIA-hosted vision models to describe
images and charts), indexes them in Milvus and exposes a chat UI over the index.
"""
import base64
import logging
import os
import subprocess
from io import BytesIO

import fitz
import requests
import streamlit as st
from llama_index.core import Document
from llama_index.core import Settings
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.nvidia import NVIDIAEmbedding
from llama_index.llms.nvidia import NVIDIA
from llama_index.vector_stores.milvus import MilvusVectorStore
from PIL import Image
from pptx import Presentation

logger = logging.getLogger(__name__)

# Chat model used both as the app's LLM and for explaining extracted chart tables.
CHAT_MODEL = "meta/llama-3.1-70b-instruct"
# Seconds to wait for an NVIDIA VLM endpoint before giving up.
VLM_REQUEST_TIMEOUT_S = 120
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
PPT_EXTENSIONS = (".ppt", ".pptx")
# Words in an image description that mark it as a chart-like image worth running through DePlot.
GRAPH_KEYWORDS = ("graph", "plot", "chart", "table")


def set_environment_variables():
    """Check that the NVIDIA API key is available to the NVIDIA clients.

    The key must be supplied through the ``NVIDIA_API_KEY`` environment variable
    (e.g. ``export NVIDIA_API_KEY=...`` before ``streamlit run``); it must never be
    hard-coded here, since anything in source is visible to everyone with access
    to the repository. Any key that was ever committed should be revoked.

    Raises:
        ValueError: if ``NVIDIA_API_KEY`` is not set.
    """
    if not os.getenv("NVIDIA_API_KEY"):
        raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")


def get_b64_image_from_content(image_content):
    """Re-encode raw image bytes as a base64 JPEG string.

    Non-RGB images (RGBA, palette, greyscale, ...) are converted to RGB first,
    because JPEG cannot store those modes.
    """
    img = Image.open(BytesIO(image_content))
    if img.mode != "RGB":
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _query_nvidia_vlm(invoke_url, prompt, image_content, generation_params):
    """Send ``prompt`` plus an inline image to an NVIDIA VLM endpoint; return the reply text.

    Raises:
        ValueError: if ``NVIDIA_API_KEY`` is not set.
        requests.HTTPError: if the endpoint answers with an error status.
        requests.Timeout: if the endpoint does not answer within ``VLM_REQUEST_TIMEOUT_S``.
    """
    api_key = os.getenv("NVIDIA_API_KEY")
    if not api_key:
        raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")
    image_b64 = get_b64_image_from_content(image_content)
    # NOTE(review): NVIDIA's examples cap inline base64 images at roughly 180 kB and
    # point to the assets API for larger ones -- confirm the limit for these endpoints.
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
    }
    payload = {
        "messages": [
            {
                "role": "user",
                # The image has to be embedded in the message content itself; without
                # this tag the model only ever sees the text prompt.
                "content": f'{prompt}<img src="data:image/jpeg;base64,{image_b64}" />',
            }
        ],
        **generation_params,
        "stream": False,
    }
    response = requests.post(invoke_url, headers=headers, json=payload, timeout=VLM_REQUEST_TIMEOUT_S)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]


def is_graph(image_content):
    """Return True if the VLM's description of the image mentions a graph, plot, chart or table."""
    description = describe_image(image_content).lower()
    return any(keyword in description for keyword in GRAPH_KEYWORDS)


def process_graph(image_content):
    """Explain a chart image in plain English.

    DePlot first turns the chart into a linearized data table; the chat model then
    rewrites that table as prose suitable for retrieval.
    """
    deplot_description = process_graph_deplot(image_content)
    llm = NVIDIA(model=CHAT_MODEL)
    response = llm.complete(
        "Your responsibility is to explain charts. You are an expert in describing the responses "
        "of linearized tables into plain English text for LLMs to use. "
        "Explain the following linearized table. " + deplot_description
    )
    return response.text


def describe_image(image_content):
    """Return a free-text description of the image from NVIDIA's NeVA-22B model."""
    return _query_nvidia_vlm(
        "https://ai.api.nvidia.com/v1/vlm/nvidia/neva-22b",
        "Describe what you see in this image. ",
        image_content,
        {"max_tokens": 1024, "temperature": 0.20, "top_p": 0.70, "seed": 0},
    )


def process_graph_deplot(image_content):
    """Return the linearized data table that Google's DePlot model extracts from a chart image."""
    return _query_nvidia_vlm(
        "https://ai.api.nvidia.com/v1/vlm/google/deplot",
        "Generate underlying data table of the figure below: ",
        image_content,
        {"max_tokens": 1024, "temperature": 0.20, "top_p": 0.20},
    )


def _safe_graph_description(image_content, label, default=""):
    """Explain the image if it looks like a chart; return ``default`` otherwise or if analysis fails.

    Failures (undecodable image formats, API errors) are logged instead of raised,
    so one bad image does not abort the whole file.
    """
    try:
        if is_graph(image_content):
            return process_graph(image_content)
    except Exception:
        logger.exception("Could not analyse image %s; continuing without a description.", label)
    return default


def _join_caption(*parts):
    """Join the non-blank caption parts with single spaces."""
    return " ".join(part.strip() for part in parts if part and part.strip())


def _file_stem(filename):
    """Return the base name of ``filename`` without directories or extension."""
    return os.path.splitext(os.path.basename(filename))[0]


def extract_text_around_item(text_blocks, bbox, page_height, threshold_percentage=0.1):
    """Find the closest text block above and below an item (table or image) on a page.

    Args:
        text_blocks: PyMuPDF ``"blocks"`` tuples ``(x0, y0, x1, y1, text, ...)``, in
            top-to-bottom order (callers use ``get_text("blocks", sort=True)``).
        bbox: ``fitz.Rect`` of the item.
        page_height: page height; sets the vertical search window.
        threshold_percentage: the vertical window as a fraction of ``page_height``
            and the allowed horizontal gap as a fraction of the item's width.

    Returns:
        ``(before_text, after_text)``; either is ``""`` when no block qualifies.
    """
    before_text, after_text = "", ""
    vertical_threshold_distance = page_height * threshold_percentage
    horizontal_threshold_distance = bbox.width * threshold_percentage
    for block in text_blocks:
        block_bbox = fitz.Rect(block[:4])
        vertical_distance = min(abs(block_bbox.y1 - bbox.y0), abs(block_bbox.y0 - bbox.y1))
        # Positive: width of the horizontal overlap; negative: size of the horizontal gap.
        # (Clamping this at 0 would make the gap check below always pass.)
        horizontal_overlap = min(block_bbox.x1, bbox.x1) - max(block_bbox.x0, bbox.x0)
        if vertical_distance > vertical_threshold_distance or horizontal_overlap < -horizontal_threshold_distance:
            continue
        if block_bbox.y1 < bbox.y0:
            # Blocks arrive top-to-bottom, so the last qualifying one above is the closest.
            before_text = block[4]
        elif block_bbox.y0 > bbox.y1:
            # The first qualifying block below is the closest; nothing later can be closer.
            after_text = block[4]
            break
    return before_text, after_text


def process_text_blocks(text_blocks, char_count_threshold=500):
    """Group consecutive text blocks into chunks of at most ``char_count_threshold`` characters.

    A single block longer than the threshold forms a group on its own. Only text
    blocks (last tuple field == 0) are considered.

    Returns:
        A list of ``(first_block, joined_text)`` tuples, where ``joined_text`` is the
        group's block texts joined by newlines.
    """
    grouped_blocks = []
    current_group = []
    current_char_count = 0

    def flush():
        if current_group:
            grouped_blocks.append((current_group[0], "\n".join(b[4] for b in current_group)))

    for block in text_blocks:
        if block[-1] != 0:  # block type 0 is text; skip image blocks
            continue
        block_char_count = len(block[4])
        if current_char_count + block_char_count <= char_count_threshold:
            current_group.append(block)
            current_char_count += block_char_count
        else:
            flush()
            current_group = [block]
            current_char_count = block_char_count
    flush()
    return grouped_blocks


def save_uploaded_file(uploaded_file):
    """Write an uploaded file to ``vectorstore/ppt_references/tmp`` and return its path.

    Only the base name of the client-supplied file name is used, so a crafted name
    cannot escape the temporary directory.
    """
    temp_dir = os.path.join(os.getcwd(), "vectorstore", "ppt_references", "tmp")
    os.makedirs(temp_dir, exist_ok=True)
    temp_file_path = os.path.join(temp_dir, os.path.basename(uploaded_file.name))
    with open(temp_file_path, "wb") as temp_file:
        temp_file.write(uploaded_file.read())
    return temp_file_path


def get_pdf_documents(pdf_file):
    """Turn a PDF into ``Document``s for its text chunks, tables and captioned images.

    Args:
        pdf_file: a binary file-like object with ``read()`` and ``name``.

    Returns:
        A list of ``Document``s; ``[]`` if the PDF cannot be opened.
    """
    all_pdf_documents = []
    ongoing_tables = {}
    try:
        pdf_doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
    except Exception as e:
        logger.error("Error opening or processing the PDF file: %s", e)
        return []

    # ``with`` closes the document even if page processing raises.
    with pdf_doc:
        for i, page in enumerate(pdf_doc):
            page_height = page.rect.height
            # Keep text blocks (type 0) that lie outside the top and bottom 10% bands,
            # which typically hold running headers and footers.
            text_blocks = [
                block
                for block in page.get_text("blocks", sort=True)
                if block[-1] == 0 and not (block[1] < page_height * 0.1 or block[3] > page_height * 0.9)
            ]
            grouped_text_blocks = process_text_blocks(text_blocks)

            table_docs, table_bboxes, ongoing_tables = parse_all_tables(
                pdf_file.name, page, i, text_blocks, ongoing_tables
            )
            all_pdf_documents.extend(table_docs)
            all_pdf_documents.extend(parse_all_images(pdf_file.name, page, i, text_blocks))

            for text_block_ctr, (heading_block, content) in enumerate(grouped_text_blocks, 1):
                heading_bbox = fitz.Rect(heading_block[:4])
                # Text that lies inside a table is already covered by that table's document.
                if any(heading_bbox.intersects(table_bbox) for table_bbox in table_bboxes):
                    continue
                block_id = f"{pdf_file.name[:-4]}-page{i}-block{text_block_ctr}"
                bbox = {"x1": heading_block[0], "y1": heading_block[1], "x2": heading_block[2], "y2": heading_block[3]}
                # NOTE(review): ``content`` already starts with the heading block's text, so
                # the heading appears twice -- confirm whether that emphasis is intended.
                all_pdf_documents.append(
                    Document(
                        text=f"{heading_block[4]}\n{content}",
                        metadata={**bbox, "type": "text", "page_num": i, "source": block_id},
                        id_=block_id,
                    )
                )
    return all_pdf_documents


def parse_all_tables(filename, page, pagenum, text_blocks, ongoing_tables):
    """Extract line-ruled tables from a PDF page as ``Document``s.

    For each table, an ``.xlsx`` copy and a ``.jpg`` snapshot are written to
    ``vectorstore/table_references``, and the snapshot is explained by the chart
    pipeline. A table that fails is logged and skipped; the others are still
    processed.

    Returns:
        ``(table_docs, table_bboxes, ongoing_tables)``. ``table_bboxes`` holds only the
        tables that produced a document. ``ongoing_tables`` is passed through unchanged.
    """
    table_docs = []
    table_bboxes = []
    try:
        tables = page.find_tables(horizontal_strategy="lines_strict", vertical_strategy="lines_strict")
    except Exception as e:
        logger.error("Error during table extraction: %s", e)
        return table_docs, table_bboxes, ongoing_tables

    tablerefdir = os.path.join(os.getcwd(), "vectorstore", "table_references")
    # Prefix reference files with the document name so tables from different files don't overwrite each other.
    stem = _file_stem(filename)
    for tab in tables:
        if tab.header.external:
            continue
        table_num = len(table_docs) + 1
        try:
            pandas_df = tab.to_pandas()
            os.makedirs(tablerefdir, exist_ok=True)
            df_xlsx_path = os.path.join(tablerefdir, f"{stem}-table{table_num}-page{pagenum}.xlsx")
            pandas_df.to_excel(df_xlsx_path)

            bbox = fitz.Rect(tab.bbox)
            before_text, after_text = extract_text_around_item(text_blocks, bbox, page.rect.height)

            table_img = page.get_pixmap(clip=bbox)
            table_img_path = os.path.join(tablerefdir, f"{stem}-table{table_num}-page{pagenum}.jpg")
            table_img.save(table_img_path)
            description = process_graph(table_img.tobytes())

            if before_text or after_text:
                caption = _join_caption(before_text.replace("\n", " "), description, after_text.replace("\n", " "))
            else:
                # No nearby text: fall back to the header cells (empty ones are None) plus the description.
                header_names = " ".join(name for name in tab.header.names if name)
                caption = _join_caption(header_names, description)

            table_metadata = {
                "source": f"{filename[:-4]}-page{pagenum}-table{table_num}",
                "dataframe": df_xlsx_path,
                "image": table_img_path,
                "caption": caption,
                "type": "table",
                "page_num": pagenum,
            }
            all_cols = ", ".join(str(col) for col in pandas_df.columns)
            doc = Document(
                text=f"This is a table with the caption: {caption}\nThe columns are {all_cols}",
                metadata=table_metadata,
            )
        except Exception as e:
            logger.error("Error during table extraction: %s", e)
            continue
        table_bboxes.append(bbox)
        table_docs.append(doc)
    return table_docs, table_bboxes, ongoing_tables


def parse_all_images(filename, page, pagenum, text_blocks):
    """Extract images from a PDF page as ``Document``s.

    Skipped images: inline images (xref 0), images smaller than 1/20 of the page in
    either dimension, and images with no nearby text to caption them. Every other
    image is saved under ``vectorstore/image_references`` in its native format.
    """
    image_docs = []
    page_rect = page.rect
    imgrefpath = os.path.join(os.getcwd(), "vectorstore", "image_references")
    stem = _file_stem(filename)
    for image_info in page.get_image_info(xrefs=True):
        xref = image_info["xref"]
        if xref == 0:
            continue
        img_bbox = fitz.Rect(image_info["bbox"])
        if img_bbox.width < page_rect.width / 20 or img_bbox.height < page_rect.height / 20:
            continue
        # Check for a caption before extracting or writing anything, so skipped images leave no files behind.
        before_text, after_text = extract_text_around_item(text_blocks, img_bbox, page_rect.height)
        if not before_text and not after_text:
            continue

        extracted_image = page.parent.extract_image(xref)
        if not extracted_image:
            continue
        image_data = extracted_image["image"]
        ext = extracted_image.get("ext") or "png"
        os.makedirs(imgrefpath, exist_ok=True)
        image_path = os.path.join(imgrefpath, f"{stem}-image{xref}-page{pagenum}.{ext}")
        with open(image_path, "wb") as img_file:
            img_file.write(image_data)

        image_description = _safe_graph_description(image_data, image_path)
        caption = _join_caption(before_text.replace("\n", " "), image_description, after_text.replace("\n", " "))
        image_metadata = {
            "source": f"{filename[:-4]}-page{pagenum}-image{xref}",
            "image": image_path,
            "caption": caption,
            "type": "image",
            "page_num": pagenum,
        }
        image_docs.append(Document(text="This is an image with the caption: " + caption, metadata=image_metadata))
    return image_docs


def process_ppt_file(ppt_path):
    """Turn a PowerPoint deck into one ``Document`` per slide.

    The deck is rendered to PDF and then to slide images. Chart-like slides get a
    chart explanation. Slide text plus any explanation become the document text;
    speaker notes are added to the caption metadata only.
    """
    pdf_path = convert_ppt_to_pdf(ppt_path)
    images_data = convert_pdf_to_images(pdf_path)
    slide_texts = extract_text_and_notes_from_ppt(ppt_path)
    if len(images_data) != len(slide_texts):
        # NOTE(review): hidden slides may be left out of the PDF export, which would
        # misalign the rendered pages with the python-pptx slide texts -- confirm.
        logger.warning(
            "%s: %d rendered pages vs %d slides; pairing them in order.",
            ppt_path, len(images_data), len(slide_texts),
        )

    processed_data = []
    for (image_path, page_num), (slide_text, notes) in zip(images_data, slide_texts):
        if notes:
            notes = "\n\nThe speaker notes for this slide are: " + notes
        with open(image_path, "rb") as image_file:
            image_content = image_file.read()
        slide_summary = _join_caption(slide_text, _safe_graph_description(image_content, image_path))
        image_metadata = {
            "source": os.path.basename(ppt_path),
            "image": image_path,
            "caption": slide_summary + notes,
            "type": "image",
            "page_num": page_num,
        }
        processed_data.append(Document(text="This is a slide with the text: " + slide_summary, metadata=image_metadata))
    return processed_data


def convert_ppt_to_pdf(ppt_path):
    """Convert a PowerPoint file to PDF with headless LibreOffice.

    LibreOffice names the output after the input's base name, unchanged, so the
    returned path keeps any spaces in that name.

    Returns:
        Absolute path of the generated PDF under ``vectorstore/ppt_references``.

    Raises:
        subprocess.CalledProcessError: if LibreOffice exits with a non-zero status.
        FileNotFoundError: if LibreOffice is not installed or the PDF was not produced.
    """
    ppt_name_without_ext = os.path.splitext(os.path.basename(ppt_path))[0]
    new_dir_path = os.path.abspath("vectorstore/ppt_references")
    os.makedirs(new_dir_path, exist_ok=True)
    pdf_path = os.path.join(new_dir_path, f"{ppt_name_without_ext}.pdf")
    command = ["libreoffice", "--headless", "--convert-to", "pdf", "--outdir", new_dir_path, ppt_path]
    subprocess.run(command, check=True)
    if not os.path.exists(pdf_path):
        raise FileNotFoundError(f"LibreOffice did not produce the expected PDF: {pdf_path}")
    return pdf_path


def convert_pdf_to_images(pdf_path):
    """Render every PDF page to a PNG in ``vectorstore/ppt_references``.

    Returns:
        A list of ``(image_path, page_index)`` tuples in page order; files are named
        ``<pdf stem with spaces as underscores>_<4-digit page index>.png``.
    """
    pdf_name_without_ext = _file_stem(pdf_path).replace(" ", "_")
    new_dir_path = os.path.join(os.getcwd(), "vectorstore/ppt_references")
    os.makedirs(new_dir_path, exist_ok=True)
    image_paths = []
    with fitz.open(pdf_path) as doc:
        for page_num, page in enumerate(doc):
            output_image_path = os.path.join(new_dir_path, f"{pdf_name_without_ext}_{page_num:04d}.png")
            page.get_pixmap().save(output_image_path)
            image_paths.append((output_image_path, page_num))
    return image_paths


def extract_text_and_notes_from_ppt(ppt_path):
    """Return ``(slide_text, notes)`` for each slide of a PowerPoint file.

    ``slide_text`` is the text of all text-bearing shapes, joined by spaces.
    ``notes`` is ``""`` when the slide has no notes. ``has_notes_slide`` is checked
    first because accessing ``notes_slide`` creates one as a side effect.
    """
    prs = Presentation(ppt_path)
    text_and_notes = []
    for slide in prs.slides:
        slide_text = " ".join(shape.text for shape in slide.shapes if hasattr(shape, "text"))
        notes = ""
        if slide.has_notes_slide:
            notes_frame = slide.notes_slide.notes_text_frame
            if notes_frame is not None:  # the notes placeholder may be missing
                notes = notes_frame.text
        text_and_notes.append((slide_text, notes))
    return text_and_notes


def load_multimodal_data(files):
    """Convert uploaded files (images, PDFs, PowerPoints, UTF-8 text) into ``Document``s.

    A file that fails to process is logged and skipped, so the rest of the batch
    still loads. Undecodable bytes in text files are replaced rather than raised.
    """
    documents = []
    for file in files:
        file_extension = os.path.splitext(file.name.lower())[1]
        if file_extension in IMAGE_EXTENSIONS:
            try:
                image_text = describe_image(file.read())
            except Exception as e:
                logger.error("Error processing image %s: %s", file.name, e)
                continue
            documents.append(Document(text=image_text, metadata={"source": file.name, "type": "image"}))
        elif file_extension == ".pdf":
            try:
                documents.extend(get_pdf_documents(file))
            except Exception as e:
                logger.error("Error processing PDF %s: %s", file.name, e)
        elif file_extension in PPT_EXTENSIONS:
            try:
                documents.extend(process_ppt_file(save_uploaded_file(file)))
            except Exception as e:
                logger.error("Error processing PPT %s: %s", file.name, e)
        else:
            text = file.read().decode("utf-8", errors="replace")
            documents.append(Document(text=text, metadata={"source": file.name, "type": "text"}))
    return documents


def load_data_from_directory(directory):
    """Convert every regular file in ``directory`` (not recursive) into ``Document``s.

    Sub-directories are skipped, and files are processed in sorted name order.
    Error handling matches ``load_multimodal_data``: failures are logged and the
    file is skipped.
    """
    documents = []
    for filename in sorted(os.listdir(directory)):
        filepath = os.path.join(directory, filename)
        if not os.path.isfile(filepath):
            continue
        file_extension = os.path.splitext(filename.lower())[1]
        logger.info("Processing %s", filename)
        if file_extension in IMAGE_EXTENSIONS:
            try:
                with open(filepath, "rb") as image_file:
                    image_text = describe_image(image_file.read())
            except Exception as e:
                logger.error("Error processing image %s: %s", filename, e)
                continue
            documents.append(Document(text=image_text, metadata={"source": filename, "type": "image"}))
        elif file_extension == ".pdf":
            try:
                with open(filepath, "rb") as pdf_file:
                    documents.extend(get_pdf_documents(pdf_file))
            except Exception as e:
                logger.error("Error processing PDF %s: %s", filename, e)
        elif file_extension in PPT_EXTENSIONS:
            try:
                documents.extend(process_ppt_file(filepath))
            except Exception as e:
                logger.error("Error processing PPT %s: %s", filename, e)
        else:
            with open(filepath, "r", encoding="utf-8", errors="replace") as text_file:
                text = text_file.read()
            documents.append(Document(text=text, metadata={"source": filename, "type": "text"}))
    return documents


# Must run before any other Streamlit command.
st.set_page_config(layout="wide")


def initialize_settings():
    """Configure the global LlamaIndex embedding model, LLM and text splitter."""
    Settings.embed_model = NVIDIAEmbedding(model="nvidia/nv-embedqa-e5-v5", truncate="END")
    Settings.llm = NVIDIA(model=CHAT_MODEL)
    Settings.text_splitter = SentenceSplitter(chunk_size=600)


def create_index(documents, overwrite=False):
    """Embed ``documents`` into a Milvus collection and return a ``VectorStoreIndex`` over it.

    Args:
        documents: the ``Document``s to index.
        overwrite: passed to ``MilvusVectorStore``. With the default ``False``,
            documents from earlier runs stay in the collection and can still be retrieved.
    """
    # dim=1024 must match the embedding size of the configured embedding model.
    vector_store = MilvusVectorStore(host="127.0.0.1", port=19530, dim=1024, overwrite=overwrite)
    # CPU-only alternative: MilvusVectorStore(uri="./milvus_demo.db", dim=1024, overwrite=True)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    return VectorStoreIndex.from_documents(documents, storage_context=storage_context)


def _build_index_into_session(documents, success_message):
    """Index ``documents``, store the index in session state, reset the chat, and report the result."""
    if not documents:
        st.warning("No content could be extracted from the provided input; the index was not updated.")
        return
    st.session_state["index"] = create_index(documents)
    st.session_state["history"] = []
    st.success(success_message)


def main():
    """Render the app: ingestion controls on the left, chat over the index on the right."""
    try:
        set_environment_variables()
    except ValueError as err:
        st.error(str(err))
        st.stop()
    initialize_settings()

    col1, col2 = st.columns([1, 2])

    with col1:
        st.title("Multimodal RAG")
        input_method = st.radio("Choose input method:", ("Upload Files", "Enter Directory Path"))

        if input_method == "Upload Files":
            uploaded_files = st.file_uploader("Drag and drop files here", accept_multiple_files=True)
            if uploaded_files and st.button("Process Files"):
                with st.spinner("Processing files..."):
                    documents = load_multimodal_data(uploaded_files)
                    _build_index_into_session(documents, "Files processed and index created!")
        else:
            directory_path = st.text_input("Enter directory path:")
            if directory_path and st.button("Process Directory"):
                if os.path.isdir(directory_path):
                    with st.spinner("Processing directory..."):
                        documents = load_data_from_directory(directory_path)
                        _build_index_into_session(documents, "Directory processed and index created!")
                else:
                    st.error("Invalid directory path. Please enter a valid path.")

    with col2:
        if "index" not in st.session_state:
            return
        st.title("Chat")
        if "history" not in st.session_state:
            st.session_state["history"] = []

        query_engine = st.session_state["index"].as_query_engine(similarity_top_k=5, streaming=True)
        user_input = st.chat_input("Enter your query:")

        # Replay the conversation so far.
        chat_container = st.container()
        with chat_container:
            for message in st.session_state["history"]:
                with st.chat_message(message["role"]):
                    st.markdown(message["content"])

        if user_input:
            with st.chat_message("user"):
                st.markdown(user_input)
            st.session_state["history"].append({"role": "user", "content": user_input})

            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                full_response = ""
                response = query_engine.query(user_input)
                # Stream tokens as they arrive, with a cursor glyph until the answer completes.
                for token in response.response_gen:
                    full_response += token
                    message_placeholder.markdown(full_response + "▌")
                message_placeholder.markdown(full_response)
            st.session_state["history"].append({"role": "assistant", "content": full_response})

        if st.button("Clear Chat"):
            st.session_state["history"] = []
            st.rerun()


if __name__ == "__main__":
    main()