# Based on code from the Mistral AI official cookbook
import os
import numpy as np
import PyPDF2
import faiss
import gradio as gr
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from llama_index.core import Settings, SimpleDirectoryReader, StorageContext, VectorStoreIndex
from llama_index.llms.mistralai import MistralAI
from llama_index.embeddings.mistralai import MistralAIEmbedding
from llama_index.vector_stores.milvus import MilvusVectorStore
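
# Read the Mistral API key from the environment and create the shared client
# used for embeddings and chat completions.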
mistral_api_key = os.environ.get("API_KEY")
cli = MistralClient(api_key=mistral_api_key)
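
# Wire llama_index to Mistral for both generation and embeddings. This wiring is
# not in the original file (it imports MistralAI and MistralAIEmbedding without
# using them); without it, VectorStoreIndex would fall back to the default
# OpenAI embedding model. A minimal sketch using llama_index's global Settings:
Settings.llm = MistralAI(api_key=mistral_api_key, model="open-mistral-7b")
Settings.embed_model = MistralAIEmbedding(model_name="mistral-embed", api_key=mistral_api_key)

# Embed a single string with the mistral-embed model and return the vector.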
def get_text_embedding(input: str):
    embeddings_batch_response = cli.embeddings(
        model="mistral-embed",
        input=input
    )
    return embeddings_batch_response.data[0].embedding
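
# Split the extracted PDF texts into 4096-character chunks, embed every chunk,
# and return the 4 chunks closest to the question according to a faiss L2 search.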
def rag_pdf(pdfs: list, question: str) -> str:
    chunk_size = 4096
    chunks = []
    for pdf in pdfs:
        chunks += [pdf[i:i + chunk_size] for i in range(0, len(pdf), chunk_size)]
    text_embeddings = np.array([get_text_embedding(chunk) for chunk in chunks])
    d = text_embeddings.shape[1]
    index = faiss.IndexFlatL2(d)
    index.add(text_embeddings)
    question_embeddings = np.array([get_text_embedding(question)])
    D, I = index.search(question_embeddings, k=4)
    retrieved_chunk = [chunks[i] for i in I.tolist()[0]]
    text_retrieved = "\n\n".join(retrieved_chunk)
    return text_retrieved
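
# Load the attached files with SimpleDirectoryReader and index them into a
# local Milvus Lite store, returning the resulting VectorStoreIndex.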
def load_doc(path_list):
    documents = SimpleDirectoryReader(input_files=path_list).load_data()
    print("Document ID:", documents[0].doc_id)
    # mistral-embed vectors are 1024-dimensional
    vector_store = MilvusVectorStore(uri="./milvus_demo.db", dim=1024, overwrite=True)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
    return index
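
# Gradio chat handler. Collects files attached to the current message and to
# earlier turns; when documents are present it streams an answer from the
# llama_index query engine and then a second, faiss-retrieval-augmented answer
# from open-mistral-7b, otherwise it chats with open-mistral-7b directly.
# Partial responses are yielded so Gradio can stream them.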
def ask_mistral(message: str, history: list):
    messages = []
    docs = message["files"]
    for couple in history:
        if type(couple[0]) is tuple:
            docs.append(couple[0][0])
        else:
            messages.append(ChatMessage(role="user", content=couple[0]))
            messages.append(ChatMessage(role="assistant", content=couple[1]))

    if docs:
        print(docs)
        # Stream an answer from the Milvus-backed llama_index query engine
        index = load_doc(docs)
        query_engine = index.as_query_engine(streaming=True)
        response = query_engine.query(message["text"])
        full_response = ""
        for text in response.response_gen:
            full_response += text
            yield full_response

        # Then build a retrieval-augmented prompt from the raw PDF text via faiss
        pdfs_extracted = []
        for pdf in docs:
            reader = PyPDF2.PdfReader(pdf)
            txt = ""
            for page in reader.pages:
                txt += page.extract_text()
            pdfs_extracted.append(txt)
        retrieved_text = rag_pdf(pdfs_extracted, message["text"])
        print(f'retrieved_text: {retrieved_text}')
        messages.append(ChatMessage(role="user", content=retrieved_text + "\n\n" + message["text"]))
    else:
        messages.append(ChatMessage(role="user", content=message["text"]))

    print(f'messages: {messages}')
    full_response = ""
    response = cli.chat_stream(
        model="open-mistral-7b",
        messages=messages,
        max_tokens=4096)
    for chunk in response:
        full_response += chunk.choices[0].delta.content
        yield full_response
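
# Build the Gradio UI: a multimodal ChatInterface (text + file uploads) wired to ask_mistral.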
chatbot = gr.Chatbot()
with gr.Blocks(theme="soft") as demo:
    gr.ChatInterface(
        fn=ask_mistral,
        title="Ask Mistral and talk to your PDFs",
        multimodal=True,
        chatbot=chatbot,
    )

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)