Spaces: Runtime error

Update app_dialogue.py

app_dialogue.py (CHANGED, +70 -32)
@@ -8,6 +8,11 @@ subprocess.run(
     shell=True,
 )
 
+# Install RAG dependencies
+subprocess.run(
+    "pip install langchain sentence-transformers faiss-cpu",
+    shell=True,
+)
 
 import copy
 import spaces
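Note: the hunk above installs langchain, sentence-transformers, and faiss-cpu at runtime, but the RAG code added further down (see the @@ -40,6 +44,66 @@ hunk) references Document, RecursiveCharacterTextSplitter, HuggingFaceEmbeddings, FAISS, VectorStore, json, and Tool with no corresponding imports anywhere in the diff, which would raise NameError at startup and is consistent with the Space's "Runtime error" status. A minimal sketch of the imports that code appears to assume, using the classic langchain module layout (paths may differ if the langchain-community split is in use):

import json

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from transformers.agents import Tool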
@@ -25,7 +30,6 @@ import gradio as gr
 from transformers import AutoProcessor, TextIteratorStreamer
 from transformers import Idefics2ForConditionalGeneration
 
-
 DEVICE = torch.device("cuda")
 MODELS = {
     "idefics2-8b-chatty": Idefics2ForConditionalGeneration.from_pretrained(
@@ -40,6 +44,66 @@ PROCESSOR = AutoProcessor.from_pretrained(
     # "Ali-C137/idefics2-8b-chatty-yalla",
 )
 
+# Load the custom dataset
+knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
+
+# Process the documents
+source_docs = [
+    Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
+    for doc in knowledge_base
+]
+docs_processed = RecursiveCharacterTextSplitter(chunk_size=500).split_documents(source_docs)[:1000]
+
+# Create embeddings and vector store
+embedding_model = HuggingFaceEmbeddings("thenlper/gte-small")
+vectordb = FAISS.from_documents(documents=docs_processed, embedding=embedding_model)
+
+class RetrieverTool(Tool):
+    name = "retriever"
+    description = "Retrieves documents from the knowledge base that have the closest embeddings to the input query."
+    inputs = {
+        "query": {
+            "type": "text",
+            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
+        },
+        "source": {
+            "type": "text",
+            "description": "",
+        },
+    }
+    output_type = "text"
+
+    def __init__(self, vectordb: VectorStore, all_sources: str, **kwargs):
+        super().__init__(**kwargs)
+        self.vectordb = vectordb
+        self.inputs["source"]["description"] = (
+            f"The source of the documents to search, as a str representation of a list. Possible values in the list are: {all_sources}. If this argument is not provided, all sources will be searched."
+        )
+
+    def forward(self, query: str, source: str = None) -> str:
+        assert isinstance(query, str), "Your search query must be a string"
+
+        if source:
+            if isinstance(source, str) and "[" not in str(source):  # if the source is not representing a list
+                source = [source]
+            source = json.loads(str(source).replace("'", '"'))
+
+        docs = self.vectordb.similarity_search(query, filter=({"source": source} if source else None), k=3)
+
+        if len(docs) == 0:
+            return "No documents found with this filtering. Try removing the source filter."
+        return "Retrieved documents:\n\n" + "\n===Document===\n".join(
+            [doc.page_content for doc in docs]
+        )
+
+from transformers.agents import HfEngine, ReactJsonAgent
+
+# Initialize the LLM engine and the agent with the retriever tool
+llm_engine = HfEngine("meta-llama/Meta-Llama-3-8B-Instruct")
+all_sources = list(set([doc.metadata["source"] for doc in docs_processed]))
+retriever_tool = RetrieverTool(vectordb, all_sources)
+agent = ReactJsonAgent(tools=[retriever_tool], llm_engine=llm_engine)
+
 # Should change this section for the finetuned model
 SYSTEM_PROMPT = [
     {
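Two caveats on the hunk above, then a usage sketch. First, langchain's HuggingFaceEmbeddings is a pydantic-style class that takes its model as a keyword argument, so the positional call shown would fail; HuggingFaceEmbeddings(model_name="thenlper/gte-small") is the form the API accepts. Second, the all_sources: str annotation is misleading, since a list is passed at the call site. Below is a minimal smoke test of the wiring, assuming Tool.__call__ dispatches to forward() as in transformers.agents; the query text and the 'transformers' source value are illustrative placeholders, not values taken from the Space:

# Direct tool call: returns the three nearest chunks as one formatted string
print(retriever_tool(query="How to upload a model to the Hub", source="['transformers']"))

# Agent-mediated call: the ReAct agent decides when to invoke the retriever
answer = agent.run("How can I push a model to the Hub?")
print(answer)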
@@ -92,12 +156,10 @@ EXAMPLES = [
 # BOT_AVATAR = "IDEFICS_logo.png"
 BOT_AVATAR = "YALLA_logo.png"
 
-
 # Chatbot utils
 def turn_is_pure_media(turn):
     return turn[1] is None
 
-
 def load_image_from_url(url):
     with urllib.request.urlopen(url) as response:
         image_data = response.read()
@@ -105,7 +167,6 @@ def load_image_from_url(url):
         image = Image.open(image_stream)
         return image
 
-
 def img_to_bytes(image_path):
     image = Image.open(image_path).convert(mode='RGB')
     buffer = io.BytesIO()
@@ -114,7 +175,6 @@ def img_to_bytes(image_path):
     image.close()
     return img_bytes
 
-
 def format_user_prompt_with_im_history_and_system_conditioning(
     user_prompt, chat_history
 ) -> List[Dict[str, Union[List, str]]]:
@@ -179,7 +239,6 @@ def format_user_prompt_with_im_history_and_system_conditioning(
 
     return resulting_messages, resulting_images
 
-
 def extract_images_from_msg_list(msg_list):
     all_images = []
     for msg in msg_list:
@@ -188,8 +247,6 @@ def extract_images_from_msg_list(msg_list):
                 all_images.append(c_)
     return all_images
 
-
-# comment this call of spaces.GPU later
 @spaces.GPU(duration=60, queue=False)
 def model_inference(
     user_prompt,
@@ -214,7 +271,6 @@ def model_inference(
     )
 
     # Common parameters to all decoding strategies
-    # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
     generation_args = {
         "max_new_tokens": max_new_tokens,
         "repetition_penalty": repetition_penalty,
@@ -233,10 +289,7 @@
         generation_args["top_p"] = top_p
 
     # Creating model inputs
-    (
-        resulting_text,
-        resulting_images,
-    ) = format_user_prompt_with_im_history_and_system_conditioning(
+    resulting_text, resulting_images = format_user_prompt_with_im_history_and_system_conditioning(
         user_prompt=user_prompt,
         chat_history=chat_history,
     )
@@ -249,20 +302,17 @@
     inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
     generation_args.update(inputs)
 
-    #
-
-
-    # generated_text = PROCESSOR.batch_decode(generated_ids[:, generation_args["input_ids"].size(-1): ], skip_special_tokens=True)[0]
-    # return generated_text
+    # Use the agent to perform RAG
+    agent_output = agent.run(user_prompt["text"])
+    print("Agent output:", agent_output)
 
-    #
+    # Stream the generated text
     thread = Thread(
         target=MODELS[model_selector].generate,
         kwargs=generation_args,
     )
     thread.start()
 
-    print("Start generating")
     acc_text = ""
     for text_token in streamer:
         time.sleep(0.04)
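A gap worth flagging in the hunk above: agent_output is printed but never added to generation_args or to the prompt, so the retrieved context never reaches Idefics2 and the reply is generated from the chat history alone. A minimal sketch of one way to close the loop, hypothetical and assuming it runs before format_user_prompt_with_im_history_and_system_conditioning builds the prompt:

# Prepend the agent's retrieved context so it actually conditions the generation
agent_output = agent.run(user_prompt["text"])
user_prompt = {
    **user_prompt,
    "text": f"Context from the knowledge base:\n{agent_output}\n\nUser question: {user_prompt['text']}",
}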
@@ -273,7 +323,6 @@
     print("Success - generated the following text:", acc_text)
     print("-----")
 
-
 FEATURES = datasets.Features(
     {
         "model_selector": datasets.Value("string"),
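For context, the FEATURES schema above (only its first field is visible in this hunk; the consuming logging code sits outside the diff) is how datasets pins column types when conversations are persisted. A hypothetical sketch of the usual consumption pattern; the row values are invented placeholders and the remaining fields stay elided, as in the diff:

import datasets

row = {"model_selector": "idefics2-8b-chatty"}  # plus the other declared fields
ds = datasets.Dataset.from_dict(
    {k: [v] for k, v in row.items()},
    # features=FEATURES,  # enable once every declared field is populated
)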
@@ -287,7 +336,6 @@ FEATURES = datasets.Features(
     }
 )
 
-
 # Hyper-parameters for generation
 max_new_tokens = gr.Slider(
     minimum=8,
@@ -337,23 +385,14 @@ top_p = gr.Slider(
     info="Higher values is equivalent to sampling more low-probability tokens.",
 )
 
-
 chatbot = gr.Chatbot(
     label="YALLA-Chatty",
     avatar_images=[None, BOT_AVATAR],
     height=450,
 )
 
-# with gr.Blocks(
-#     fill_height=True, # Use this below !?
-#     css=""".gradio-container .avatar-container {height: 40px width: 40px !important;} #duplicate-button {margin: auto; color: white; background: #f1a139; border-radius: 100vh; margin-top: 2px; margin-bottom: 2px;}""",
-# ) as demo:
 with gr.Blocks(fill_height=True) as demo:
     gr.Markdown("# 🇲🇦 YALLA ")
-    # gr.Markdown("In this demo you'll be able to chat with YALLA, a variant of [Idefics2-8B](https://huggingface.co/HuggingFaceM4/idefics2-8b-chatty) further fine-tuned on chat datasets, and Moroccan culture 🇲🇦")
-    # gr.Markdown("If you want to learn more about Idefics2 and its variants, you can check our [blog post](https://huggingface.co/blog/idefics2).")
-    # gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
-    # model selector should be set to `visbile=False` ultimately
     with gr.Row(elem_id="model_selector_row"):
         model_selector = gr.Dropdown(
             choices=MODELS.keys(),
@@ -390,7 +429,6 @@ with gr.Blocks(fill_height=True) as demo:
         fn=model_inference,
         chatbot=chatbot,
         examples=EXAMPLES,
-        # multimodal=True,
         multimodal=False,
         cache_examples=False,
         additional_inputs=[