orrinin committed on
Commit
8364e36
·
verified ·
1 Parent(s): 88735ad

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #using codes from mistralai official cookbook
2
+ import gradio as gr
3
+ from mistralai.client import MistralClient
4
+ from mistralai.models.chat_completion import ChatMessage
5
+ import numpy as np
6
+ import PyPDF2
7
+ import faiss
8
+ import os
9
+
10
# The Mistral API key is injected through the environment (e.g. a Space secret).
mistral_api_key = os.environ.get("API_KEY")

# Single shared client instance, reused by every request handler below.
cli = MistralClient(api_key=mistral_api_key)
13
+
14
def get_text_embedding(input: str):
    """Embed one piece of text with Mistral's ``mistral-embed`` model.

    Args:
        input: the text to embed.

    Returns:
        The embedding vector (a list of floats) for *input*.
    """
    response = cli.embeddings(model="mistral-embed", input=input)
    return response.data[0].embedding
20
+
21
def rag_pdf(pdfs: list, question: str) -> str:
    """Retrieve the chunks of the given PDF texts most relevant to *question*.

    Each PDF's extracted text is split into fixed-size character chunks,
    embedded with Mistral's embedding model, and indexed in a FAISS L2 index.
    The question is embedded the same way and the nearest chunks are returned.

    Args:
        pdfs: list of extracted PDF text strings.
        question: the user's question to retrieve context for.

    Returns:
        The retrieved chunk texts joined with blank lines, or "" when the
        PDFs contain no extractable text.
    """
    chunk_size = 4096  # characters per chunk (a coarse proxy for tokens)
    chunks = []
    for pdf in pdfs:
        chunks += [pdf[i:i + chunk_size] for i in range(0, len(pdf), chunk_size)]

    # Robustness fix: with no text at all (e.g. scanned image-only PDFs),
    # the embedding array would be empty and .shape[1] would raise.
    if not chunks:
        return ""

    text_embeddings = np.array([get_text_embedding(chunk) for chunk in chunks])
    d = text_embeddings.shape[1]
    index = faiss.IndexFlatL2(d)
    index.add(text_embeddings)

    question_embeddings = np.array([get_text_embedding(question)])
    # Bug fix: with fewer than 4 chunks, a hard-coded k=4 makes FAISS pad the
    # result with index -1, which would silently retrieve chunks[-1] as a
    # duplicate. Clamp k to the number of available chunks.
    k = min(4, len(chunks))
    D, I = index.search(question_embeddings, k=k)
    retrieved_chunk = [chunks[i] for i in I.tolist()[0]]
    return "\n\n".join(retrieved_chunk)
37
+
38
def ask_mistral(message: str, history: list):
    """Gradio multimodal chat handler: stream an answer, optionally grounded
    in uploaded PDFs via a small RAG step.

    Args:
        message: Gradio multimodal payload, a dict with "text" and "files".
        history: prior (user, assistant) turns; file-upload turns appear with
            a tuple of file paths as the user side.

    Yields:
        The progressively accumulated assistant response (for streaming UI).
    """
    messages = []
    # Copy so we never mutate the list Gradio passed in (it may reuse it).
    pdfs = list(message["files"])
    for couple in history:
        if type(couple[0]) is tuple:
            # A file-upload turn: couple[0] is a tuple of file paths.
            pdfs += couple[0]
        else:
            messages.append(ChatMessage(role="user", content=couple[0]))
            messages.append(ChatMessage(role="assistant", content=couple[1]))

    if pdfs:
        # Extract plain text from every PDF seen so far.
        pdfs_extracted = []
        for pdf in pdfs:
            reader = PyPDF2.PdfReader(pdf)
            txt = ""
            for page in reader.pages:
                txt += page.extract_text()
            pdfs_extracted.append(txt)

        # Prepend the retrieved context to the user's question.
        retrieved_text = rag_pdf(pdfs_extracted, message["text"])
        messages.append(ChatMessage(role="user", content=retrieved_text + "\n\n" + message["text"]))
    else:
        messages.append(ChatMessage(role="user", content=message["text"]))

    full_response = ""
    for chunk in cli.chat_stream(model="open-mistral-7b", messages=messages, max_tokens=1024):
        # Bug fix: delta.content can be None on stream boundary events
        # (role-only / final chunks); "+= None" would raise TypeError.
        full_response += chunk.choices[0].delta.content or ""
        yield full_response
66
+
67
# Multimodal chat UI: the input box accepts both text and PDF file uploads,
# all routed through ask_mistral (which streams its reply back).
app = gr.ChatInterface(
    fn=ask_mistral,
    title="Ask Mistral and talk to your PDFs",
    multimodal=True,
)
app.launch()