# Adapted from the Mistral AI official cookbook.
import os

import faiss
import gradio as gr
import numpy as np
import PyPDF2
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

from llama_index.core import Settings, SimpleDirectoryReader, StorageContext, VectorStoreIndex
from llama_index.embeddings.mistralai import MistralAIEmbedding
from llama_index.llms.mistralai import MistralAI
from llama_index.vector_stores.milvus import MilvusVectorStore

mistral_api_key = os.environ.get("API_KEY")
cli = MistralClient(api_key=mistral_api_key)

# Route LlamaIndex through Mistral for both generation and embeddings
# (mistral-embed produces 1024-dimensional vectors).
Settings.llm = MistralAI(model="open-mistral-7b", api_key=mistral_api_key)
Settings.embed_model = MistralAIEmbedding(model_name="mistral-embed", api_key=mistral_api_key)

# Set to True to answer over uploaded documents with the LlamaIndex + Milvus
# pipeline (load_doc) instead of the manual PyPDF2 + FAISS pipeline (rag_pdf).
USE_LLAMAINDEX = False


def get_text_embedding(text: str):
    """Embed a single string with mistral-embed."""
    embeddings_batch_response = cli.embeddings(model="mistral-embed", input=text)
    return embeddings_batch_response.data[0].embedding


def rag_pdf(pdfs: list, question: str) -> str:
    """Split the extracted PDF texts into fixed-size chunks, index them with
    FAISS, and return the four chunks closest to the question."""
    chunk_size = 4096
    chunks = []
    for pdf in pdfs:
        chunks += [pdf[i:i + chunk_size] for i in range(0, len(pdf), chunk_size)]

    # FAISS expects float32 vectors.
    text_embeddings = np.array([get_text_embedding(chunk) for chunk in chunks], dtype="float32")
    d = text_embeddings.shape[1]
    index = faiss.IndexFlatL2(d)
    index.add(text_embeddings)

    question_embeddings = np.array([get_text_embedding(question)], dtype="float32")
    _, indices = index.search(question_embeddings, k=4)
    retrieved_chunks = [chunks[i] for i in indices.tolist()[0]]
    return "\n\n".join(retrieved_chunks)


def load_doc(path_list: list):
    """Alternative pipeline: index the documents into a local Milvus store via
    LlamaIndex and return the resulting index."""
    documents = SimpleDirectoryReader(input_files=path_list).load_data()
    print("Document ID:", documents[0].doc_id)
    # dim must match the embedding model; mistral-embed is 1024-dimensional.
    vector_store = MilvusVectorStore(uri="./milvus_demo.db", dim=1024, overwrite=True)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
    return index


def ask_mistral(message: dict, history: list):
    messages = []
    docs = list(message["files"])

    # Rebuild the conversation, collecting any files uploaded in earlier turns.
    for couple in history:
        if isinstance(couple[0], tuple):
            docs.append(couple[0][0])
        else:
            messages.append(ChatMessage(role="user", content=couple[0]))
            messages.append(ChatMessage(role="assistant", content=couple[1]))

    if docs:
        print(docs)

        if USE_LLAMAINDEX:
            # Let the LlamaIndex query engine retrieve and answer in one step,
            # streaming the response back to the UI.
            index = load_doc(docs)
            query_engine = index.as_query_engine(streaming=True)
            response = query_engine.query(message["text"])
            full_response = ""
            for text in response.response_gen:
                full_response += text
                yield full_response
            return

        # Default path: extract the raw text of every PDF with PyPDF2 ...
        pdfs_extracted = []
        for pdf in docs:
            reader = PyPDF2.PdfReader(pdf)
            txt = ""
            for page in reader.pages:
                txt += page.extract_text()
            pdfs_extracted.append(txt)

        # ... then retrieve the most relevant chunks and prepend them to the question.
        retrieved_text = rag_pdf(pdfs_extracted, message["text"])
        print(f"retrieved_text: {retrieved_text}")
        messages.append(ChatMessage(role="user", content=retrieved_text + "\n\n" + message["text"]))
    else:
        messages.append(ChatMessage(role="user", content=message["text"]))

    print(f"messages: {messages}")
    full_response = ""
    response = cli.chat_stream(model="open-mistral-7b", messages=messages, max_tokens=4096)
    for chunk in response:
        content = chunk.choices[0].delta.content
        if content is not None:
            full_response += content
            yield full_response


chatbot = gr.Chatbot()

with gr.Blocks(theme="soft") as demo:
    gr.ChatInterface(
        fn=ask_mistral,
        title="Ask Mistral and talk to your PDFs",
        multimodal=True,
        chatbot=chatbot,
    )

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)
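
# ---------------------------------------------------------------------------
# Minimal usage sketch (not from the cookbook): exercising the FAISS retrieval
# path without Gradio. Assumes API_KEY is exported and that "example.pdf" is a
# hypothetical local file; swap in a real path to try it.
# ---------------------------------------------------------------------------
# reader = PyPDF2.PdfReader("example.pdf")
# text = "".join(page.extract_text() for page in reader.pages)
# question = "What is this document about?"
# context = rag_pdf([text], question)
# answer = cli.chat(
#     model="open-mistral-7b",
#     messages=[ChatMessage(role="user", content=context + "\n\n" + question)],
# )
# print(answer.choices[0].message.content)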