# Adapted from the MistralAI official cookbook
import os

import faiss
import gradio as gr
import numpy as np
import PyPDF2
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from llama_index.core import Settings, SimpleDirectoryReader, StorageContext, VectorStoreIndex
from llama_index.embeddings.mistralai import MistralAIEmbedding
from llama_index.llms.mistralai import MistralAI
from llama_index.vector_stores.milvus import MilvusVectorStore
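
# The imports above assume the mistralai (v0.x) Python client and
# llama-index >= 0.10 with the llama-index-llms-mistralai,
# llama-index-embeddings-mistralai and llama-index-vector-stores-milvus
# integration packages installed. The Mistral API key is read from the
# API_KEY environment variable below.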


mistral_api_key = os.environ.get("API_KEY")

cli = MistralClient(api_key = mistral_api_key)
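
# NOTE (assumption): route llama_index's LLM and embedding calls through Mistral.
# The MistralAI / MistralAIEmbedding imports suggest this was the intent;
# without it, VectorStoreIndex below would fall back to llama_index's default
# (OpenAI) models.
Settings.llm = MistralAI(api_key = mistral_api_key, model = "open-mistral-7b")
Settings.embed_model = MistralAIEmbedding(model_name = "mistral-embed", api_key = mistral_api_key)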

def get_text_embedding(input: str):
    embeddings_batch_response = cli.embeddings(
          model = "mistral-embed",
          input = input
      )
    return embeddings_batch_response.data[0].embedding

def rag_pdf(pdfs: list, question: str) -> str:
    # Split each PDF's text into fixed-size character chunks.
    chunk_size = 4096
    chunks = []
    for pdf in pdfs:
        chunks += [pdf[i:i + chunk_size] for i in range(0, len(pdf), chunk_size)]

    # Embed every chunk and index the vectors in a flat L2 faiss index.
    text_embeddings = np.array([get_text_embedding(chunk) for chunk in chunks])
    d = text_embeddings.shape[1]
    index = faiss.IndexFlatL2(d)
    index.add(text_embeddings)

    # Embed the question and return the 4 closest chunks as the context.
    question_embeddings = np.array([get_text_embedding(question)])
    D, I = index.search(question_embeddings, k = 4)
    retrieved_chunk = [chunks[i] for i in I.tolist()[0]]
    text_retrieved = "\n\n".join(retrieved_chunk)
    return text_retrieved
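
# Example (hypothetical inputs) of using the faiss retriever on its own:
#   pdf_texts = ["full text of contract.pdf", "full text of annex.pdf"]
#   context = rag_pdf(pdf_texts, "What is the notice period?")
#   print(context)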

def load_doc(path_list):
    # Read the uploaded files and index them in a local Milvus store.
    documents = SimpleDirectoryReader(input_files=path_list).load_data()
    print("Document ID:", documents[0].doc_id)
    # dim=1024 matches mistral-embed; overwrite=True rebuilds the store on every call.
    vector_store = MilvusVectorStore(uri="./milvus_demo.db", dim=1024, overwrite=True)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
    return index



def ask_mistral(message: dict, history: list):
    messages = []
    docs = list(message["files"])
    # Rebuild the conversation: file uploads appear in the history as tuples,
    # plain text turns as (user, assistant) pairs.
    for couple in history:
        if type(couple[0]) is tuple:
            docs.append(couple[0][0])
        else:
            messages.append(ChatMessage(role = "user", content = couple[0]))
            messages.append(ChatMessage(role = "assistant", content = couple[1]))

    if docs:
        print(docs)
        # First pass: index the uploaded files with llama_index / Milvus and
        # stream the query engine's answer directly.
        index = load_doc(docs)
        query_engine = index.as_query_engine(streaming=True)
        response = query_engine.query(message["text"])

        full_response = ""
        for text in response.response_gen:
            full_response += text
            yield full_response

        # Second pass (the original cookbook's approach): extract the PDFs with
        # PyPDF2, retrieve the most relevant chunks with faiss, and prepend them
        # to the question so the chat model below answers with that context.
        # Its streamed answer replaces the query engine's answer in the UI.
        pdfs_extracted = []
        for pdf in docs:
            reader = PyPDF2.PdfReader(pdf)
            txt = ""
            for page in reader.pages:
                txt += page.extract_text()
            pdfs_extracted.append(txt)

        retrieved_text = rag_pdf(pdfs_extracted, message["text"])
        print(f'retrieved_text: {retrieved_text}')
        messages.append(ChatMessage(role = "user", content = retrieved_text + "\n\n" + message["text"]))
    else:
        messages.append(ChatMessage(role = "user", content = message["text"]))
    print(f'messages: {messages}')

    full_response = ""

    # Stream the final answer from the Mistral chat API.
    response = cli.chat_stream(
        model = "open-mistral-7b",
        messages = messages,
        max_tokens = 4096)

    for chunk in response:
        full_response += chunk.choices[0].delta.content
        yield full_response



chatbot = gr.Chatbot()

with gr.Blocks(theme="soft") as demo:
    gr.ChatInterface(
        fn = ask_mistral,
        title = "Ask Mistral and talk to your PDFs", 
        multimodal = True,
        chatbot=chatbot,
    )

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)
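
# To run locally (assuming the packages listed with the imports are installed):
#   export API_KEY=<your Mistral API key>
#   python <path to this script>
# The Milvus index is written next to the script as ./milvus_demo.db and is
# rebuilt each time files are uploaded (overwrite=True in load_doc).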