plaguss committed
Commit c733466 · verified · 1 Parent(s): 6efe5ed

Upload app.py

Files changed (1)
  1. app.py +352 -0
app.py ADDED
@@ -0,0 +1,352 @@
+ from typing import Optional, Any, Generator
+ import os
+ from pathlib import Path
+ import tarfile
+ from dataclasses import dataclass
+
+ import torch
+ import lancedb
+ from lancedb.embeddings import get_registry
+ from huggingface_hub.file_download import hf_hub_download
+ from huggingface_hub import InferenceClient
+ from transformers import AutoTokenizer
+ import gradio as gr
+
+
+ @dataclass
+ class Settings:
+     """Settings class to store useful variables for the App."""
+
+     LANCEDB: str = "lancedb"
+     LANCEDB_FILE_TAR: str = "lancedb.tar.gz"
+     TOKEN: str = os.getenv("HF_API_TOKEN")
+     LOCAL_DIR: Path = Path.home() / ".cache/argilla_sdk_docs_db"
+     REPO_ID: str = "plaguss/argilla_sdk_docs_queries"
+     TABLE_NAME: str = "docs"
+     MODEL_NAME: str = "plaguss/bge-base-argilla-sdk-matryoshka"
+     DEVICE: str = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
+     MODEL_ID: str = "meta-llama/Meta-Llama-3-70B-Instruct"
+
+ settings = Settings()
+
+
+ def untar_file(source: Path) -> Path:
+     """Untar and decompress files that were created with `make_tarfile`.
+
+     Args:
+         source (Path): Path pointing to a .tar.gz file.
+
+     Returns:
+         filename (Path): The path of the decompressed file.
+     """
+     new_filename = source.parent / source.stem.replace(".tar", "")
+     with tarfile.open(source, "r:gz") as f:
+         f.extractall(source.parent)
+     return new_filename
+
+
+ def download_database(
+     repo_id: str,
+     lancedb_file: str = "lancedb.tar.gz",
+     local_dir: Path = Path.home() / ".cache/argilla_sdk_docs_db",
+     token: str = os.getenv("HF_API_TOKEN"),
+ ) -> Path:
+     """Helper function to download the database. Will download a compressed lancedb stored
+     in a Hugging Face repository.
+
+     Args:
+         repo_id: Name of the repository where the database file is stored.
+         lancedb_file: Name of the compressed file containing the lancedb database.
+             Defaults to "lancedb.tar.gz".
+         local_dir: Path where the file will be downloaded to. Defaults to
+             Path.home() / ".cache/argilla_sdk_docs_db".
+         token: Token for the Hugging Face Hub API. Defaults to os.getenv("HF_API_TOKEN").
+
+     Returns:
+         The path pointing to the database, already uncompressed and ready to be used.
+     """
+     lancedb_download = Path(
+         hf_hub_download(
+             repo_id,
+             lancedb_file,
+             repo_type="dataset",
+             token=token,
+             local_dir=local_dir,
+         )
+     )
+     return untar_file(lancedb_download)
+
+
+ # Get the model to create the embeddings
+ model = get_registry().get("sentence-transformers").create(name=settings.MODEL_NAME, device=settings.DEVICE)
+
+
+ class Database:
+     """Interaction with the vector database to retrieve the chunks.
+
+     On instantiation, will download the lancedb database if not already found in
+     the expected location. Once ready, the only functionality available is
+     to retrieve the doc chunks to be used as examples for the LLM.
+     """
+
+     def __init__(self, settings: Settings) -> None:
+         self.settings = settings
+         self._table: lancedb.table.LanceTable = self.get_table_from_db()
+
+     def get_table_from_db(self) -> lancedb.table.LanceTable:
+         """Downloads the database containing the embedded docs.
+
+         If the file is not found in the expected location, will download it, and
+         then create the connection, open the table and return it.
+
+         Returns:
+             The table of the database containing the embedded chunks.
+         """
+         lancedb_db_path = self.settings.LOCAL_DIR / self.settings.LANCEDB
+
+         if not lancedb_db_path.exists():
+             lancedb_db_path = download_database(
+                 self.settings.REPO_ID,
+                 lancedb_file=self.settings.LANCEDB_FILE_TAR,
+                 local_dir=self.settings.LOCAL_DIR,
+                 token=self.settings.TOKEN,
+             )
+
+         db = lancedb.connect(str(lancedb_db_path))
+         table = db.open_table(self.settings.TABLE_NAME)
+         return table
+
+     def retrieve_doc_chunks(self, query: str, limit: int = 12, hard_limit: int = 4) -> str:
+         """Search the database for queries similar to the user's message and build
+         a context string from the most relevant doc chunks.
+
+         TODO: SPLIT IN TWO SEPARATE FUNCTIONS TO PREPARE THE CONTEXT.
+
+         Args:
+             query (str): The user's query to search for in the database.
+             limit (int, optional): Number of rows to retrieve before deduplicating
+                 the chunks. Defaults to 12.
+             hard_limit (int, optional): Maximum number of unique chunks included in
+                 the final context. Defaults to 4.
+
+         Returns:
+             str: The context string built from the retrieved doc chunks.
+         """
+         # Embed the query to use our custom model instead of the default one.
+         embedded_query = model.generate_embeddings([query])
+         field_to_retrieve = "text"
+         retrieved = (
+             self._table
+             .search(embedded_query[0])
+             .metric("cosine")
+             .limit(limit)
+             .select([field_to_retrieve])  # Just grab the chunk to use for context
+             .to_list()
+         )
+         # We have repeated questions (up to 4) for a given chunk, so we may get repeated chunks.
+         # Request more than necessary and filter them afterwards.
+         responses = []
+         unique_responses = set()
+
+         for item in retrieved:
+             chunk = item["text"]
+             if chunk not in unique_responses:
+                 unique_responses.add(chunk)
+                 responses.append(chunk)
+
+         context = ""
+         for i, item in enumerate(responses[:hard_limit]):
+             if i > 0:
+                 context += "\n\n"
+             context += f"---\n{item}"
+         return context
+
+
+ database = Database(settings=settings)
+
+
+ def get_client_and_tokenizer(
+     model_id: str = settings.MODEL_ID,
+     tokenizer_id: Optional[str] = None,
+ ) -> tuple[InferenceClient, AutoTokenizer]:
+     """Obtains the inference client and the tokenizer corresponding to the model.
+
+     Args:
+         model_id: The name of the model. Currently it must be one in the free tier.
+             Defaults to "meta-llama/Meta-Llama-3-70B-Instruct".
+         tokenizer_id: The name of the corresponding tokenizer. Defaults to None,
+             in which case it will use the same as the `model_id`.
+
+     Returns:
+         The client and tokenizer chosen.
+     """
+     if tokenizer_id is None:
+         tokenizer_id = model_id
+
+     client = InferenceClient()
+     base_url = client._resolve_url(
+         model=model_id, task="text-generation"
+     )
+     # Note: We could move to the AsyncClient
+     client = InferenceClient(
+         model=base_url,
+         token=os.getenv("HF_API_TOKEN"),
+     )
+
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+     return client, tokenizer
+
+
+ client_kwargs = {
+     "stream": True,
+     "max_new_tokens": 512,
+     "do_sample": False,
+     "typical_p": None,
+     "repetition_penalty": None,
+     "temperature": 0.3,
+     "top_p": None,
+     "top_k": None,
+     "stop_sequences": ["<|eot_id|>", "<|end_of_text|>"] if settings.MODEL_ID.startswith("meta-llama/Meta-Llama-3") else None,
+     "seed": None,
+ }
+
+
+ client, tokenizer = get_client_and_tokenizer()
+
+ SYSTEM_PROMPT = """\
+ You are a support expert in Argilla SDK, whose goal is to help users with their questions.
+ As a trustworthy expert, you must provide truthful answers to questions using only the provided documentation snippets, not prior knowledge.
+ Here are guidelines you must follow when responding to user questions:
+
+ **Purpose and Functionality**
+ - Answer questions related to the Argilla SDK.
+ - Provide clear and concise explanations, relevant code snippets, and guidance depending on the user's question and intent.
+ - Ensure users succeed in effectively understanding and using Argilla's features.
+ - Provide accurate responses to the user's questions.
+
+ **Specificity**
+ - Be specific and provide details only when required.
+ - Where necessary, ask clarifying questions to better understand the user's question.
+ - Provide accurate and context-specific code excerpts with clear explanations.
+ - Ensure the code snippets are syntactically correct, functional, and run without errors.
+ - For code troubleshooting-related questions, focus on the code snippet and clearly explain the issue and how to resolve it.
+ - Avoid boilerplate code such as imports, installs, etc.
+
+ **Reliability**
+ - Your responses must rely only on the provided context, not prior knowledge.
+ - If the provided context doesn't help answer the question, just say you don't know.
+ - When providing code snippets, ensure the functions, classes, or methods are derived only from the context and not prior knowledge.
+ - Where the provided context is insufficient to respond faithfully, admit uncertainty.
+ - Remind the user of your specialization in Argilla SDK support when a question is outside your domain of expertise.
+ - Redirect the user to the appropriate support channel - the Argilla [community](https://join.slack.com/t/rubrixworkspace/shared_invite/zt-whigkyjn-a3IUJLD7gDbTZ0rKlvcJ5g) - when the question is outside your capabilities or you do not have enough context to answer the question.
+
+ **Response Style**
+ - Use clear, concise, professional language suitable for technical support.
+ - Do not refer to the context in the response (e.g., "As mentioned in the context..."); instead, provide the information directly in the response.
+
+ **Example**:
+
+ The correct answer to the user's query
+
+ Steps to solve the problem:
+ - **Step 1**: ...
+ - **Step 2**: ...
+ ...
+
+ Here's a code snippet
+
+ ```python
+ # Code example
+ ...
+ ```
+
+ **Explanation**:
+
+ - Point 1
+ - Point 2
+ ...
+ """
+
+ ARGILLA_BOT_TEMPLATE = """\
+ Please provide an answer to the following question related to Argilla's new SDK.
+
+ You can make use of the chunks of documents in the context to help you generate the response.
+
+ ## Query:
+ {message}
+
+ ## Context:
+ {context}
+ """
+
+
+ def prepare_input(message: str, history: list[tuple[str, str]]) -> str:
+     """Prepares the input to be passed to the LLM.
+
+     Args:
+         message: Message from the user, the query.
+         history: Previous list of messages from the user and the answers, as a list
+             of tuples with user/assistant messages.
+
+     Returns:
+         The string with the template formatted to be sent to the LLM.
+     """
+     # Retrieve the context from the database
+     context = database.retrieve_doc_chunks(message)
+
+     # Prepare the conversation for the model.
+     conversation = []
+     for human, bot in history:
+         conversation.append({"role": "user", "content": human})
+         conversation.append({"role": "assistant", "content": bot})
+
+     conversation.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
+     conversation.append(
+         {
+             "role": "user",
+             "content": ARGILLA_BOT_TEMPLATE.format(message=message, context=context),
+         }
+     )
+
+     return tokenizer.apply_chat_template(
+         [conversation],
+         tokenize=False,
+         add_generation_prompt=True,
+     )[0]
+
+
+ def chatty(message: str, history: list[tuple[str, str]]) -> Generator[str, None, None]:
+     """Main function of the app, contains the interaction with the LLM.
+
+     Args:
+         message: Message from the user, the query.
+         history: Previous list of messages from the user and the answers, as a list
+             of tuples with user/assistant messages.
+
+     Yields:
+         The streaming response; it is printed in the interface as it is received.
+     """
+     prompt = prepare_input(message, history)
+
+     partial_message = ""
+     for token_stream in client.text_generation(prompt=prompt, **client_kwargs):
+         partial_message += token_stream
+         yield partial_message
+
+
+ if __name__ == "__main__":
+
+     gr.ChatInterface(
+         chatty,
+         chatbot=gr.Chatbot(height=600),
+         textbox=gr.Textbox(placeholder="Ask me about the new argilla SDK", container=False, scale=7),
+         title="Argilla SDK Chatbot",
+         description="Ask a question about Argilla SDK",
+         theme="soft",
+         examples=[
+             "How can I connect to an argilla server?",
+             "How can I access a dataset?",
+             "How can I get the current user?",
+         ],
+         cache_examples=True,
+         retry_btn=None,
+     ).launch()
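
For reference, a minimal sketch (not part of the commit) of how the pieces in this `app.py` could be exercised outside the Gradio UI. It assumes the file is importable as `app`, that `HF_API_TOKEN` is set, and that the lancedb tarball is reachable on the Hub; importing the module runs the database download and client setup at import time.

```python
# Hypothetical local usage of the retrieval + generation flow defined in app.py.
from app import database, prepare_input, client, client_kwargs

question = "How can I connect to an argilla server?"

# 1. Retrieve the deduplicated doc chunks used as context for the LLM.
context = database.retrieve_doc_chunks(question, limit=12, hard_limit=4)
print(context)

# 2. Build the full prompt (system prompt + chat template) and stream a reply.
prompt = prepare_input(question, history=[])
for token in client.text_generation(prompt=prompt, **client_kwargs):
    print(token, end="", flush=True)
```

Since the same `client_kwargs` drive both this sketch and the `chatty` generator, the streamed output should match what the chat interface displays.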