orrinin committed on
Commit 5379f04 · verified · 1 Parent(s): 1186823

Update app.py

Files changed (1)
  1. app.py +127 -89
app.py CHANGED
@@ -1,105 +1,143 @@
- #using codes from mistralai official cookbook
- import gradio as gr
- from llama_index.llms import MistralAI
- import numpy as np
- import PyPDF2
- import faiss
  import os
  from llama_index.core import SimpleDirectoryReader
- from llama_index.embeddings import MistralAIEmbedding
- from llama_index import ServiceContext
- from llama_index.core import VectorStoreIndex, StorageContext
- from llama_index.vector_stores.milvus import MilvusVectorStore
- import textwrap
-
-
- mistral_api_key = os.environ.get("API_KEY")
-
- cli = MistralClient(api_key = mistral_api_key)
-
- def get_text_embedding(input: str):
-     embeddings_batch_response = cli.embeddings(
-         model = "mistral-embed",
-         input = input
-     )
-     return embeddings_batch_response.data[0].embedding
-
- def rag_pdf(pdfs: list, question: str) -> str:
-     chunk_size = 4096
-     chunks = []
-     for pdf in pdfs:
-         chunks += [pdf[i:i + chunk_size] for i in range(0, len(pdf), chunk_size)]
-
-     text_embeddings = np.array([get_text_embedding(chunk) for chunk in chunks])
-     d = text_embeddings.shape[1]
-     index = faiss.IndexFlatL2(d)
-     index.add(text_embeddings)
-
-     question_embeddings = np.array([get_text_embedding(question)])
-     D, I = index.search(question_embeddings, k = 4)
-     retrieved_chunk = [chunks[i] for i in I.tolist()[0]]
-     text_retrieved = "\n\n".join(retrieved_chunk)
-     return text_retrieved
-
- def load_doc(path_list):
      documents = SimpleDirectoryReader(input_files=path).load_data()
-     print("Document ID:", documents[0].doc_id)
-     vector_store = MilvusVectorStore(uri="./milvus_demo.db", dim=1536, overwrite=True)
-     storage_context = StorageContext.from_defaults(vector_store=vector_store)
-     index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
-     return index

- def ask_mistral(message: str, history: list):
      messages = []
-     docs = message["files"]
-     for couple in history:
-         if type(couple[0]) is tuple:
-             docs += couple[0][0]
          else:
-             messages.append(ChatMessage(role= "user", content = couple[0]))
-             messages.append(ChatMessage(role= "assistant", content = couple[1]))
-     if docs:
-         print(docs)
-         index = load_doc(docs)
-         query_engine = index.as_query_engine()
-         response = query_engine.query(message["text"])
-
-         full_response = ""
-         for text in response.response_gen:
-             full_response += chunk.choices[0].delta.content
-             yield full_response
-
-         pdfs_extracted = []
-         for pdf in pdfs:
-             reader = PyPDF2.PdfReader(pdf)
-             txt = ""
-             for page in reader.pages:
-                 txt += page.extract_text()
-             pdfs_extracted.append(txt)
-
-         retrieved_text = rag_pdf(pdfs_extracted, message["text"])
-         print(f'retrieved_text: {retrieved_text}')
-         messages.append(ChatMessage(role = "user", content = retrieved_text + "\n\n" + message["text"]))
-     else:
-         messages.append(ChatMessage(role = "user", content = message["text"]))
-     print(f'messages: {messages}')
-
-     full_response = ""
-
-     response = cli.chat_stream(
-         model = "open-mistral-7b",
-         messages = messages,
-         max_tokens = 4096)
-
-     for chunk in response:
-         full_response += chunk.choices[0].delta.content
-         yield full_response
@@ -108,7 +146,7 @@ chatbot = gr.Chatbot()
  with gr.Blocks(theme="soft") as demo:
      gr.ChatInterface(
          fn = ask_mistral,
-         title = "Ask Mistral and talk to your PDFs",
          multimodal = True,
          chatbot=chatbot,
      )
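For reference, the retrieval helper removed above (rag_pdf) follows the Mistral cookbook pattern: split each extracted PDF into fixed-size character chunks, embed every chunk, index the vectors in a flat L2 FAISS index, and return the chunks nearest to the question. Below is a small self-contained sketch of that pattern; the embed function is a toy byte-frequency stand-in for the removed "mistral-embed" call, purely for illustration.

import faiss
import numpy as np

def embed(text: str) -> np.ndarray:
    # Toy stand-in for the removed "mistral-embed" call: a 256-dim byte-frequency
    # vector. Any real embedding model can slot in here instead.
    vec = np.zeros(256, dtype="float32")
    for b in text.encode("utf-8"):
        vec[b] += 1.0
    return vec

def retrieve(pdf_texts: list[str], question: str, chunk_size: int = 4096, k: int = 4) -> str:
    # Split each extracted PDF into fixed-size character chunks.
    chunks = []
    for text in pdf_texts:
        chunks += [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

    # Embed the chunks and add them to a flat L2 FAISS index (float32 is required).
    chunk_vecs = np.stack([embed(c) for c in chunks])
    index = faiss.IndexFlatL2(chunk_vecs.shape[1])
    index.add(chunk_vecs)

    # Embed the question and join the k nearest chunks into one context string.
    _, ids = index.search(np.stack([embed(question)]), k)
    return "\n\n".join(chunks[i] for i in ids[0] if i != -1)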
  import os
+ import gradio as gr  # needed below for gr.Error and gr.Blocks
+ from bs4 import BeautifulSoup
+ from IPython.display import Markdown, display
+ from llama_index.core import Document
+ from llama_index.core import Settings
  from llama_index.core import SimpleDirectoryReader
+ from llama_index.core import StorageContext
+ from llama_index.core import VectorStoreIndex
+ from llama_index.readers.web import SimpleWebPageReader
+
+ from llama_index.vector_stores.chroma import ChromaVectorStore
+
+ import chromadb
+ import re
+ from llama_index.llms.gemini import Gemini
+ from llama_index.embeddings.gemini import GeminiEmbedding
+
+ from llama_index.core import PromptTemplate
+ from llama_index.core.llms import ChatMessage
+
+ import uuid
+
+ api_key = os.environ.get("API_KEY")
+
+ llm = Gemini(api_key=api_key, model_name="models/gemini-1.5-flash-latest")
+ gemini_embedding_model = GeminiEmbedding(api_key=api_key, model_name="models/embedding-001")
+
+ # Set global settings
+ Settings.llm = llm
+ Settings.embed_model = gemini_embedding_model
+
+ def extract_web(url):
+     web_documents = SimpleWebPageReader().load_data(
+         [url]
+     )
+     html_content = web_documents[0].text
+     # Parse the data.
+     soup = BeautifulSoup(html_content, 'html.parser')
+     p_tags = soup.findAll('p')
+     text_content = ""
+     for each in p_tags:
+         text_content += each.text + "\n"
+
+     # Convert back to Document format
+     documents = [Document(text=text_content)]
+     option = "web"
+     return documents, option
+
+ def extract_doc(path):
      documents = SimpleDirectoryReader(input_files=path).load_data()
+     option = "doc"
+     return documents, option
+
+
+ def create_col(documents):
+     # Create a client and a new collection (random short id for the on-disk path)
+     db_path = f'database/{str(uuid.uuid4())[:4]}'
+     client = chromadb.PersistentClient(path=db_path)
+     chroma_collection = client.get_or_create_collection("quickstart")
+
+     # Create a vector store
+     vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+
+     # Create a storage context
+     storage_context = StorageContext.from_defaults(vector_store=vector_store)
+     # Create an index from the documents and save it to disk.
+     VectorStoreIndex.from_documents(
+         documents, storage_context=storage_context
+     )
+     return db_path

+ def infer(message: str, history: list):
+     print(f'message: {message}')
+     print(f'history: {history}')
      messages = []
+     files_list = message["files"]
+
+     for prompt, answer in history:
+         if isinstance(prompt, tuple):
+             files_list += prompt[0]
          else:
+             messages.append(ChatMessage(role= "user", content = prompt))
+             messages.append(ChatMessage(role= "assistant", content = answer))
+
+     if files_list:
+         documents, option = extract_doc(files_list)
+     else:
+         if message["text"].startswith("http://") or message["text"].startswith("https://"):
+             documents, option = extract_web(message["text"])
+         elif not message["text"].startswith("http://") and not message["text"].startswith("https://") and len(history) == 0:
+             raise gr.Error("Please input a URL or upload a file first.")
+
+     print(documents)
+     db_path = create_col(documents)
+
+     # Load from disk
+     load_client = chromadb.PersistentClient(path=db_path)
+
+     # Fetch the collection
+     chroma_collection = load_client.get_collection("quickstart")
+
+     # Fetch the vector store
+     vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+
+     # Get the index from the vector store
+     index = VectorStoreIndex.from_vector_store(
+         vector_store
+     )
+
+     template = (
+         """You are an assistant for question-answering tasks.
+         Use the following context to answer the question.
+         If you don't know the answer, just say that you don't know.
+         Use five sentences maximum and keep the answer concise.\n
+         Question: {query_str} \nContext: {context_str} \nAnswer:"""
+     )
+     llm_prompt = PromptTemplate(template)
+     print(llm_prompt)
+
+     if option == "web" and len(history) == 0:
+         response = "Fetched the web page. You can now ask questions about it."
+     else:
+         question = message['text']
+         query_engine = index.as_query_engine(text_qa_template=llm_prompt)
+         response = query_engine.query(question)
+
+     return response
+
 
  with gr.Blocks(theme="soft") as demo:
      gr.ChatInterface(
          fn = ask_mistral,
+         title = "RAG demo",
          multimodal = True,
          chatbot=chatbot,
      )
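One loose end in the wiring above: the handler is now named infer, while the unchanged gr.ChatInterface call still passes fn = ask_mistral, which no longer exists in this version of app.py. A minimal sketch of how the tail of the file could be wired, assuming gradio is imported as gr at the top of the file, the chatbot = gr.Chatbot() line shown in the hunk context, and a standard demo.launch() call that is not part of this diff:

chatbot = gr.Chatbot()

with gr.Blocks(theme="soft") as demo:
    gr.ChatInterface(
        fn=infer,            # point the chat UI at the renamed handler
        title="RAG demo",
        multimodal=True,     # lets the textbox accept file uploads alongside text
        chatbot=chatbot,
    )

demo.launch()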