import os
import textwrap

import faiss
import gradio as gr
import numpy as np
import PyPDF2

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

from llama_index.core import Settings, SimpleDirectoryReader, StorageContext, VectorStoreIndex
from llama_index.embeddings.mistralai import MistralAIEmbedding
from llama_index.llms.mistralai import MistralAI
from llama_index.vector_stores.milvus import MilvusVectorStore
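
# Rough dependency list inferred from the imports above (package names are the
# usual PyPI ones; pin versions to match your llama-index / mistralai setup):
#   pip install gradio mistralai llama-index-core llama-index-llms-mistralai \
#       llama-index-embeddings-mistralai llama-index-vector-stores-milvus \
#       pymilvus PyPDF2 faiss-cpu numpy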

mistral_api_key = os.environ.get("API_KEY")

cli = MistralClient(api_key=mistral_api_key)
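
# The key is read from the API_KEY environment variable used above, e.g.:
#   export API_KEY="your-mistral-api-key"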


def get_text_embedding(input: str):
    """Embed a single string with Mistral's `mistral-embed` model."""
    embeddings_batch_response = cli.embeddings(
        model="mistral-embed",
        input=input,
    )
    return embeddings_batch_response.data[0].embedding
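
# Quick sanity check (assumes a valid API_KEY): mistral-embed returns
# 1024-dimensional vectors, which is what sizes the FAISS index below.
#   vec = get_text_embedding("hello world")
#   print(len(vec))  # -> 1024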


def rag_pdf(pdfs: list, question: str) -> str:
    """Retrieve the chunks of the given PDF texts most relevant to the question."""
    # Split every PDF's extracted text into fixed-size chunks.
    chunk_size = 4096
    chunks = []
    for pdf in pdfs:
        chunks += [pdf[i:i + chunk_size] for i in range(0, len(pdf), chunk_size)]

    # Embed the chunks and index them with FAISS (L2 distance).
    text_embeddings = np.array([get_text_embedding(chunk) for chunk in chunks])
    d = text_embeddings.shape[1]
    index = faiss.IndexFlatL2(d)
    index.add(text_embeddings)

    # Embed the question and pull back the 4 closest chunks as context.
    question_embeddings = np.array([get_text_embedding(question)])
    D, I = index.search(question_embeddings, k=4)
    retrieved_chunk = [chunks[i] for i in I.tolist()[0]]
    text_retrieved = "\n\n".join(retrieved_chunk)
    return text_retrieved
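
# Usage sketch (hypothetical text, not part of the app flow): given the full
# text of a PDF, fetch context for a question:
#   context = rag_pdf([full_text_of_some_pdf], "What is the total revenue?")
# The retrieved chunks come back joined by blank lines, ready to be prepended
# to the user's prompt.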


# Point LlamaIndex at Mistral for both the LLM and the embeddings, so the index
# built in load_doc() does not fall back to the OpenAI defaults. (Settings is
# the current replacement for the older ServiceContext.)
Settings.llm = MistralAI(model="open-mistral-7b", api_key=mistral_api_key)
Settings.embed_model = MistralAIEmbedding(model_name="mistral-embed", api_key=mistral_api_key)


def load_doc(path_list):
    """Index the given files in a local Milvus Lite database and return the index."""
    documents = SimpleDirectoryReader(input_files=path_list).load_data()
    print("Document ID:", documents[0].doc_id)
    # mistral-embed produces 1024-dimensional vectors, so the store must match.
    vector_store = MilvusVectorStore(uri="./milvus_demo.db", dim=1024, overwrite=True)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
    return index
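
# Minimal query sketch (hypothetical file path), mirroring how ask_mistral()
# uses the index below:
#   engine = load_doc(["./my_file.pdf"]).as_query_engine(streaming=True)
#   for token in engine.query("Summarize the document").response_gen:
#       print(token, end="")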


def ask_mistral(message: dict, history: list):
    messages = []
    docs = message["files"]
    # Rebuild the conversation: file attachments appear in history as tuples,
    # plain exchanges as (user, assistant) strings.
    for couple in history:
        if type(couple[0]) is tuple:
            docs.append(couple[0][0])
        else:
            messages.append(ChatMessage(role="user", content=couple[0]))
            messages.append(ChatMessage(role="assistant", content=couple[1]))

    if docs:
        print(docs)
        # Index the attached documents with LlamaIndex/Milvus and stream an
        # answer from the query engine (streaming=True exposes response_gen).
        index = load_doc(docs)
        query_engine = index.as_query_engine(streaming=True)
        response = query_engine.query(message["text"])

        full_response = ""
        for text in response.response_gen:
            full_response += text
            yield full_response

        # Additionally, extract the raw text with PyPDF2 and retrieve the most
        # relevant chunks with FAISS; this context grounds the Mistral chat
        # completion streamed at the end of the function.
        pdfs_extracted = []
        for pdf in docs:
            reader = PyPDF2.PdfReader(pdf)
            txt = ""
            for page in reader.pages:
                txt += page.extract_text()
            pdfs_extracted.append(txt)

        retrieved_text = rag_pdf(pdfs_extracted, message["text"])
        print(f'retrieved_text: {retrieved_text}')
        messages.append(ChatMessage(role="user", content=retrieved_text + "\n\n" + message["text"]))
    else:
        messages.append(ChatMessage(role="user", content=message["text"]))
    print(f'messages: {messages}')

    full_response = ""

    # Stream the chat completion token by token; each yield updates the
    # message shown in the Gradio UI.
    response = cli.chat_stream(
        model="open-mistral-7b",
        messages=messages,
        max_tokens=4096,
    )

    for chunk in response:
        full_response += chunk.choices[0].delta.content
        yield full_response
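
# Because ask_mistral is a generator, gr.ChatInterface re-renders the assistant
# message on every yield, which is what produces the streaming effect.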


chatbot = gr.Chatbot()

with gr.Blocks(theme="soft") as demo:
    gr.ChatInterface(
        fn=ask_mistral,
        title="Ask Mistral and talk to your PDFs",
        multimodal=True,
        chatbot=chatbot,
    )

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)
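
# To try it locally (filename assumed): save this script as app.py, export
# API_KEY, run `python app.py`, then open the printed local URL and attach a
# PDF along with your question.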