Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Any, Generator
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
import tarfile
|
5 |
+
from dataclasses import dataclass
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import lancedb
|
9 |
+
from lancedb.embeddings import get_registry
|
10 |
+
from huggingface_hub.file_download import hf_hub_download
|
11 |
+
from huggingface_hub import InferenceClient
|
12 |
+
from transformers import AutoTokenizer
|
13 |
+
import gradio as gr
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class Settings:
|
18 |
+
"""Settings class to store useful variables for the App.
|
19 |
+
"""
|
20 |
+
LANCEDB: str = "lancedb"
|
21 |
+
LANCEDB_FILE_TAR: str = "lancedb.tar.gz"
|
22 |
+
TOKEN: str = os.getenv("HF_API_TOKEN")
|
23 |
+
LOCAL_DIR: Path = Path.home() / ".cache/argilla_sdk_docs_db"
|
24 |
+
REPO_ID: str = "plaguss/argilla_sdk_docs_queries"
|
25 |
+
TABLE_NAME: str = "docs"
|
26 |
+
MODEL_NAME: str = "plaguss/bge-base-argilla-sdk-matryoshka"
|
27 |
+
DEVICE: str = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
|
28 |
+
MODEL_ID: str = "meta-llama/Meta-Llama-3-70B-Instruct"
|
29 |
+
|
30 |
+
settings = Settings()
|
31 |
+
|
32 |
+
|
33 |
+
def untar_file(source: Path) -> Path:
|
34 |
+
"""Untar and decompress files which have passed by `make_tarfile`.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
source (Path): Path pointing to a .tag.gz file.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
filename (Path): The filename of the file decompressed.
|
41 |
+
"""
|
42 |
+
new_filename = source.parent / source.stem.replace(".tar", "")
|
43 |
+
with tarfile.open(source, "r:gz") as f:
|
44 |
+
f.extractall(source.parent)
|
45 |
+
return new_filename
|
46 |
+
|
47 |
+
|
48 |
+
def download_database(
|
49 |
+
repo_id: str,
|
50 |
+
lancedb_file: str = "lancedb.tar.gz",
|
51 |
+
local_dir: Path = Path.home() / ".cache/argilla_sdk_docs_db",
|
52 |
+
token: str = os.getenv("HF_API_TOKEN")
|
53 |
+
) -> Path:
|
54 |
+
"""Helper function to download the database. Will download a compressed lancedb stored
|
55 |
+
in a Hugging Face repository.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
repo_id: Name of the repository where the databsase file is stored.
|
59 |
+
lancedb_file: Name of the compressed file containing the lancedb database.
|
60 |
+
Defaults to "lancedb.tar.gz".
|
61 |
+
local_dir: Path where the file will be donwloaded to. Defaults to
|
62 |
+
Path.home()/".cache/argilla_sdk_docs_db".
|
63 |
+
token: Token for the Hugging Face hub API. Defaults to os.getenv("HF_API_TOKEN").
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
The path pointing to the database already uncompressed and ready to be used.
|
67 |
+
"""
|
68 |
+
lancedb_download = Path(
|
69 |
+
hf_hub_download(
|
70 |
+
repo_id,
|
71 |
+
lancedb_file,
|
72 |
+
repo_type="dataset",
|
73 |
+
token=token,
|
74 |
+
local_dir=local_dir
|
75 |
+
)
|
76 |
+
)
|
77 |
+
return untar_file(lancedb_download)
|
78 |
+
|
79 |
+
|
80 |
+
# Get the model to create the embeddings
|
81 |
+
model = get_registry().get("sentence-transformers").create(name=settings.MODEL_NAME, device=settings.DEVICE)
|
82 |
+
|
83 |
+
|
84 |
+
class Database:
|
85 |
+
"""Interaction with the vector database to retrieve the chunks.
|
86 |
+
|
87 |
+
On instantiation, will donwload the lancedb database if nos already found in
|
88 |
+
the expected location. Once ready, the only functionality available is
|
89 |
+
to retrieve the doc chunks to be used as examples for the LLM.
|
90 |
+
"""
|
91 |
+
def __init__(self, settings: Settings) -> None:
|
92 |
+
self.settings = settings
|
93 |
+
self._table: lancedb.table.LanceTable = self.get_table_from_db()
|
94 |
+
|
95 |
+
def get_table_from_db(self) -> lancedb.table.LanceTable:
|
96 |
+
"""Downloads the database containing the embedded docs.
|
97 |
+
|
98 |
+
If the file is not found in the expected location, will download it, and
|
99 |
+
then create the connection, open the table and pass it.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
The table of the database containing the embedded chunks.
|
103 |
+
"""
|
104 |
+
lancedb_db_path = self.settings.LOCAL_DIR / self.settings.LANCEDB
|
105 |
+
|
106 |
+
if not lancedb_db_path.exists():
|
107 |
+
lancedb_db_path = download_database(
|
108 |
+
self.settings.REPO_ID,
|
109 |
+
lancedb_file=self.settings.LANCEDB_FILE_TAR,
|
110 |
+
local_dir=self.settings.LOCAL_DIR,
|
111 |
+
token=self.settings.TOKEN
|
112 |
+
)
|
113 |
+
|
114 |
+
db = lancedb.connect(str(lancedb_db_path))
|
115 |
+
table = db.open_table(self.settings.TABLE_NAME)
|
116 |
+
return table
|
117 |
+
|
118 |
+
def retrieve_doc_chunks(self, query: str, limit: int = 12, hard_limit: int = 4) -> str:
|
119 |
+
"""Search for similar queries in the database, and return a list with
|
120 |
+
|
121 |
+
TODO: SPLIT IN TWO SEPARATE FUNCTIONS TO PREPARE THE CONTEXT.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
query (str): _description_
|
125 |
+
limit (int, optional): _description_. Defaults to 12.
|
126 |
+
hard_limit (int, optional): _description_. Defaults to 4.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
str: _description_
|
130 |
+
"""
|
131 |
+
# Embed the query to use our custom model instead of the default one.
|
132 |
+
embedded_query = model.generate_embeddings([query])
|
133 |
+
field_to_retrieve = "text"
|
134 |
+
retrieved = (
|
135 |
+
self._table
|
136 |
+
.search(embedded_query[0])
|
137 |
+
.metric("cosine")
|
138 |
+
.limit(limit)
|
139 |
+
.select([field_to_retrieve]) # Just grab the chunk to use for context
|
140 |
+
.to_list()
|
141 |
+
)
|
142 |
+
# We have repeated questions (up to 4) for a given chunk, so we may get repeated chunks.
|
143 |
+
# Request more than necessary and filter them afterwards
|
144 |
+
responses = []
|
145 |
+
unique_responses = set()
|
146 |
+
|
147 |
+
for item in retrieved:
|
148 |
+
chunk = item["text"]
|
149 |
+
if chunk not in unique_responses:
|
150 |
+
unique_responses.add(chunk)
|
151 |
+
responses.append(chunk)
|
152 |
+
|
153 |
+
context = ""
|
154 |
+
for i, item in enumerate(responses[:hard_limit]):
|
155 |
+
if i > 0:
|
156 |
+
context += "\n\n"
|
157 |
+
context += f"---\n{item}"
|
158 |
+
return context
|
159 |
+
|
160 |
+
|
161 |
+
database = Database(settings=settings)
|
162 |
+
|
163 |
+
|
164 |
+
def get_client_and_tokenizer(
|
165 |
+
model_id: str = settings.MODEL_ID,
|
166 |
+
tokenizer_id: Optional[str] = None
|
167 |
+
) -> tuple[InferenceClient, AutoTokenizer]:
|
168 |
+
"""Obtains the inference client and the tokenizer corresponding to the model.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
model_id: The name of the model. Currently it must be one in the free tier.
|
172 |
+
Defaults to "meta-llama/Meta-Llama-3-70B-Instruct".
|
173 |
+
tokenizer_id: The name of the corresponding tokenizer. Defaults to None,
|
174 |
+
in which case it will use the same as the `model_id`.
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
The client and tokenizer chosen.
|
178 |
+
"""
|
179 |
+
if tokenizer_id is None:
|
180 |
+
tokenizer_id = model_id
|
181 |
+
|
182 |
+
client = InferenceClient()
|
183 |
+
base_url = client._resolve_url(
|
184 |
+
model=model_id, task="text-generation"
|
185 |
+
)
|
186 |
+
# Note: We could move to the AsyncClient
|
187 |
+
client = InferenceClient(
|
188 |
+
model=base_url,
|
189 |
+
token=os.getenv("HF_API_TOKEN")
|
190 |
+
)
|
191 |
+
|
192 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
|
193 |
+
return client, tokenizer
|
194 |
+
|
195 |
+
|
196 |
+
client_kwargs = {
|
197 |
+
"stream": True,
|
198 |
+
"max_new_tokens": 512,
|
199 |
+
"do_sample": False,
|
200 |
+
"typical_p": None,
|
201 |
+
"repetition_penalty": None,
|
202 |
+
"temperature": 0.3,
|
203 |
+
"top_p": None,
|
204 |
+
"top_k": None,
|
205 |
+
"stop_sequences": ["<|eot_id|>", "<|end_of_text|>"] if settings.MODEL_ID.startswith("meta-llama/Meta-Llama-3") else None,
|
206 |
+
"seed": None,
|
207 |
+
}
|
208 |
+
|
209 |
+
|
210 |
+
client, tokenizer = get_client_and_tokenizer()
|
211 |
+
|
212 |
+
SYSTEM_PROMPT = """\
|
213 |
+
You are a support expert in Argilla SDK, whose goal is help users with their questions.
|
214 |
+
As a trustworthy expert, you must provide truthful answers to questions using only the provided documentation snippets, not prior knowledge.
|
215 |
+
Here are guidelines you must follow when responding to user questions:
|
216 |
+
|
217 |
+
##Purpose and Functionality**
|
218 |
+
- Answer questions related to the Argilla SDK.
|
219 |
+
- Provide clear and concise explanations, relevant code snippets, and guidance depending on the user's question and intent.
|
220 |
+
- Ensure users succeed in effectively understanding and using Argilla's features.
|
221 |
+
- Provide accurate responses to the user's questions.
|
222 |
+
|
223 |
+
**Specificity**
|
224 |
+
- Be specific and provide details only when required.
|
225 |
+
- Where necessary, ask clarifying questions to better understand the user's question.
|
226 |
+
- Provide accurate and context-specific code excerpts with clear explanations.
|
227 |
+
- Ensure the code snippets are syntactically correct, functional, and run without errors.
|
228 |
+
- For code troubleshooting-related questions, focus on the code snippet and clearly explain the issue and how to resolve it.
|
229 |
+
- Avoid boilerplate code such as imports, installs, etc.
|
230 |
+
|
231 |
+
**Reliability**
|
232 |
+
- Your responses must rely only on the provided context, not prior knowledge.
|
233 |
+
- If the provided context doesn't help answer the question, just say you don't know.
|
234 |
+
- When providing code snippets, ensure the functions, classes, or methods are derived only from the context and not prior knowledge.
|
235 |
+
- Where the provided context is insufficient to respond faithfully, admit uncertainty.
|
236 |
+
- Remind the user of your specialization in Argilla SDK support when a question is outside your domain of expertise.
|
237 |
+
- Redirect the user to the appropriate support channels - Argilla [community](https://join.slack.com/t/rubrixworkspace/shared_invite/zt-whigkyjn-a3IUJLD7gDbTZ0rKlvcJ5g) when the question is outside your capabilities or you do not have enough context to answer the question.
|
238 |
+
|
239 |
+
**Response Style**
|
240 |
+
- Use clear, concise, professional language suitable for technical support
|
241 |
+
- Do not refer to the context in the response (e.g., "As mentioned in the context...") instead, provide the information directly in the response.
|
242 |
+
|
243 |
+
**Example**:
|
244 |
+
|
245 |
+
The correct answer to the user's query
|
246 |
+
|
247 |
+
Steps to solve the problem:
|
248 |
+
- **Step 1**: ...
|
249 |
+
- **Step 2**: ...
|
250 |
+
...
|
251 |
+
|
252 |
+
Here's a code snippet
|
253 |
+
|
254 |
+
```python
|
255 |
+
# Code example
|
256 |
+
...
|
257 |
+
```
|
258 |
+
|
259 |
+
**Explanation**:
|
260 |
+
|
261 |
+
- Point 1
|
262 |
+
- Point 2
|
263 |
+
...
|
264 |
+
"""
|
265 |
+
|
266 |
+
ARGILLA_BOT_TEMPLATE = """\
|
267 |
+
Please provide an answer to the following question related to Argilla's new SDK.
|
268 |
+
|
269 |
+
You can make use of the chunks of documents in the context to help you generating the response.
|
270 |
+
|
271 |
+
## Query:
|
272 |
+
{message}
|
273 |
+
|
274 |
+
## Context:
|
275 |
+
{context}
|
276 |
+
"""
|
277 |
+
|
278 |
+
|
279 |
+
def prepare_input(message: str, history: list[tuple[str, str]]) -> str:
|
280 |
+
"""Prepares the input to be passed to the LLM.
|
281 |
+
|
282 |
+
Args:
|
283 |
+
message: Message from the user, the query.
|
284 |
+
history: Previous list of messages from the user and the answers, as a list
|
285 |
+
of tuples with user/assistant messages.
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
The string with the template formatted to be sent to the LLM.
|
289 |
+
"""
|
290 |
+
# Retrieve the context from the database
|
291 |
+
context = database.retrieve_doc_chunks(message)
|
292 |
+
|
293 |
+
# Prepare the conversation for the model.
|
294 |
+
conversation = []
|
295 |
+
for human, bot in history:
|
296 |
+
conversation.append({"role": "user", "content": human})
|
297 |
+
conversation.append({"role": "assistant", "content": bot})
|
298 |
+
|
299 |
+
conversation.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
|
300 |
+
conversation.append(
|
301 |
+
{
|
302 |
+
"role": "user",
|
303 |
+
"content": ARGILLA_BOT_TEMPLATE.format(message=message, context=context),
|
304 |
+
}
|
305 |
+
)
|
306 |
+
|
307 |
+
return tokenizer.apply_chat_template(
|
308 |
+
[conversation],
|
309 |
+
tokenize=False,
|
310 |
+
add_generation_prompt=True,
|
311 |
+
)[0]
|
312 |
+
|
313 |
+
|
314 |
+
def chatty(message: str, history: list[tuple[str, str]]) -> Generator[str, None, None]:
|
315 |
+
"""Main function of the app, contains the interaction with the LLM.
|
316 |
+
|
317 |
+
Args:
|
318 |
+
message: Message from the user, the query.
|
319 |
+
history: Previous list of messages from the user and the answers, as a list
|
320 |
+
of tuples with user/assistant messages.
|
321 |
+
|
322 |
+
Yields:
|
323 |
+
The streaming response, it's printed in the interface as it's being received.
|
324 |
+
"""
|
325 |
+
prompt = prepare_input(message, history)
|
326 |
+
|
327 |
+
partial_message = ""
|
328 |
+
for token_stream in client.text_generation(prompt=prompt, **client_kwargs):
|
329 |
+
partial_message += token_stream
|
330 |
+
yield partial_message
|
331 |
+
|
332 |
+
|
333 |
+
|
334 |
+
if __name__ == "__main__":
|
335 |
+
|
336 |
+
import gradio as gr
|
337 |
+
|
338 |
+
gr.ChatInterface(
|
339 |
+
chatty,
|
340 |
+
chatbot=gr.Chatbot(height=600),
|
341 |
+
textbox=gr.Textbox(placeholder="Ask me about the new argilla SDK", container=False, scale=7),
|
342 |
+
title="Argilla SDK Chatbot",
|
343 |
+
description="Ask a question about Argilla SDK",
|
344 |
+
theme="soft",
|
345 |
+
examples=[
|
346 |
+
"How can I connect to an argilla server?",
|
347 |
+
"How can I access a dataset?",
|
348 |
+
"How can I get the current user?"
|
349 |
+
],
|
350 |
+
cache_examples=True,
|
351 |
+
retry_btn=None,
|
352 |
+
).launch()
|