Ali-C137 committed on
Commit a7cbbd8 · verified · 1 Parent(s): 7ba5c3a

Update app_dialogue.py

Files changed (1)
  1. app_dialogue.py +70 -32
app_dialogue.py CHANGED
@@ -8,6 +8,11 @@ subprocess.run(
     shell=True,
 )

+# Install RAG dependencies
+subprocess.run(
+    "pip install langchain sentence-transformers faiss-cpu",
+    shell=True,
+)

 import copy
 import spaces
@@ -25,7 +30,6 @@ import gradio as gr
 from transformers import AutoProcessor, TextIteratorStreamer
 from transformers import Idefics2ForConditionalGeneration

-
 DEVICE = torch.device("cuda")
 MODELS = {
     "idefics2-8b-chatty": Idefics2ForConditionalGeneration.from_pretrained(
@@ -40,6 +44,66 @@ PROCESSOR = AutoProcessor.from_pretrained(
     # "Ali-C137/idefics2-8b-chatty-yalla",
 )

+# Load the custom dataset
+knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
+
+# Process the documents
+source_docs = [
+    Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
+    for doc in knowledge_base
+]
+docs_processed = RecursiveCharacterTextSplitter(chunk_size=500).split_documents(source_docs)[:1000]
+
+# Create embeddings and vector store
+embedding_model = HuggingFaceEmbeddings("thenlper/gte-small")
+vectordb = FAISS.from_documents(documents=docs_processed, embedding=embedding_model)
+
+class RetrieverTool(Tool):
+    name = "retriever"
+    description = "Retrieves documents from the knowledge base that have the closest embeddings to the input query."
+    inputs = {
+        "query": {
+            "type": "text",
+            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
+        },
+        "source": {
+            "type": "text",
+            "description": "",
+        },
+    }
+    output_type = "text"
+
+    def __init__(self, vectordb: VectorStore, all_sources: str, **kwargs):
+        super().__init__(**kwargs)
+        self.vectordb = vectordb
+        self.inputs["source"]["description"] = (
+            f"The source of the documents to search, as a str representation of a list. Possible values in the list are: {all_sources}. If this argument is not provided, all sources will be searched."
+        )
+
+    def forward(self, query: str, source: str = None) -> str:
+        assert isinstance(query, str), "Your search query must be a string"
+
+        if source:
+            if isinstance(source, str) and "[" not in str(source):  # if the source is not representing a list
+                source = [source]
+            source = json.loads(str(source).replace("'", '"'))
+
+        docs = self.vectordb.similarity_search(query, filter=({"source": source} if source else None), k=3)
+
+        if len(docs) == 0:
+            return "No documents found with this filtering. Try removing the source filter."
+        return "Retrieved documents:\n\n" + "\n===Document===\n".join(
+            [doc.page_content for doc in docs]
+        )
+
+from transformers.agents import HfEngine, ReactJsonAgent
+
+# Initialize the LLM engine and the agent with the retriever tool
+llm_engine = HfEngine("meta-llama/Meta-Llama-3-8B-Instruct")
+all_sources = list(set([doc.metadata["source"] for doc in docs_processed]))
+retriever_tool = RetrieverTool(vectordb, all_sources)
+agent = ReactJsonAgent(tools=[retriever_tool], llm_engine=llm_engine)
+
 # Should change this section for the finetuned model
 SYSTEM_PROMPT = [
     {
@@ -92,12 +156,10 @@ EXAMPLES = [
 # BOT_AVATAR = "IDEFICS_logo.png"
 BOT_AVATAR = "YALLA_logo.png"

-
 # Chatbot utils
 def turn_is_pure_media(turn):
     return turn[1] is None

-
 def load_image_from_url(url):
     with urllib.request.urlopen(url) as response:
         image_data = response.read()
@@ -105,7 +167,6 @@ def load_image_from_url(url):
         image = Image.open(image_stream)
         return image

-
 def img_to_bytes(image_path):
     image = Image.open(image_path).convert(mode='RGB')
     buffer = io.BytesIO()
@@ -114,7 +175,6 @@ def img_to_bytes(image_path):
     image.close()
     return img_bytes

-
 def format_user_prompt_with_im_history_and_system_conditioning(
     user_prompt, chat_history
 ) -> List[Dict[str, Union[List, str]]]:
@@ -179,7 +239,6 @@ def format_user_prompt_with_im_history_and_system_conditioning(

     return resulting_messages, resulting_images

-
 def extract_images_from_msg_list(msg_list):
     all_images = []
     for msg in msg_list:
@@ -188,8 +247,6 @@ def extract_images_from_msg_list(msg_list):
                 all_images.append(c_)
     return all_images

-
-# comment this call of spaces.GPU later
 @spaces.GPU(duration=60, queue=False)
 def model_inference(
     user_prompt,
@@ -214,7 +271,6 @@ def model_inference(
     )

     # Common parameters to all decoding strategies
-    # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
     generation_args = {
         "max_new_tokens": max_new_tokens,
         "repetition_penalty": repetition_penalty,
@@ -233,10 +289,7 @@ def model_inference(
         generation_args["top_p"] = top_p

     # Creating model inputs
-    (
-        resulting_text,
-        resulting_images,
-    ) = format_user_prompt_with_im_history_and_system_conditioning(
+    resulting_text, resulting_images = format_user_prompt_with_im_history_and_system_conditioning(
         user_prompt=user_prompt,
         chat_history=chat_history,
     )
@@ -249,20 +302,17 @@ def model_inference(
     inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
     generation_args.update(inputs)

-    # # The regular non streaming generation mode
-    # _ = generation_args.pop("streamer")
-    # generated_ids = MODELS[model_selector].generate(**generation_args)
-    # generated_text = PROCESSOR.batch_decode(generated_ids[:, generation_args["input_ids"].size(-1): ], skip_special_tokens=True)[0]
-    # return generated_text
+    # Use the agent to perform RAG
+    agent_output = agent.run(user_prompt["text"])
+    print("Agent output:", agent_output)

-    # The streaming generation mode
+    # Stream the generated text
     thread = Thread(
         target=MODELS[model_selector].generate,
         kwargs=generation_args,
     )
     thread.start()

-    print("Start generating")
     acc_text = ""
     for text_token in streamer:
         time.sleep(0.04)
@@ -273,7 +323,6 @@ def model_inference(
     print("Success - generated the following text:", acc_text)
     print("-----")

-
 FEATURES = datasets.Features(
     {
         "model_selector": datasets.Value("string"),
@@ -287,7 +336,6 @@ FEATURES = datasets.Features(
     }
 )

-
 # Hyper-parameters for generation
 max_new_tokens = gr.Slider(
     minimum=8,
@@ -337,23 +385,14 @@ top_p = gr.Slider(
     info="Higher values is equivalent to sampling more low-probability tokens.",
 )

-
 chatbot = gr.Chatbot(
     label="YALLA-Chatty",
     avatar_images=[None, BOT_AVATAR],
     height=450,
 )

-# with gr.Blocks(
-#     fill_height=True, # Use this below !?
-#     css=""".gradio-container .avatar-container {height: 40px width: 40px !important;} #duplicate-button {margin: auto; color: white; background: #f1a139; border-radius: 100vh; margin-top: 2px; margin-bottom: 2px;}""",
-# ) as demo:
 with gr.Blocks(fill_height=True) as demo:
     gr.Markdown("# 🇲🇦 YALLA ")
-    # gr.Markdown("In this demo you'll be able to chat with YALLA, a variant of [Idefics2-8B](https://huggingface.co/HuggingFaceM4/idefics2-8b-chatty) further fine-tuned on chat datasets, and Moroccan culture 🇲🇦")
-    # gr.Markdown("If you want to learn more about Idefics2 and its variants, you can check our [blog post](https://huggingface.co/blog/idefics2).")
-    # gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
-    # model selector should be set to `visbile=False` ultimately
     with gr.Row(elem_id="model_selector_row"):
         model_selector = gr.Dropdown(
             choices=MODELS.keys(),
@@ -390,7 +429,6 @@ with gr.Blocks(fill_height=True) as demo:
         fn=model_inference,
         chatbot=chatbot,
         examples=EXAMPLES,
-        # multimodal=True,
         multimodal=False,
         cache_examples=False,
         additional_inputs=[
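The new RAG block builds a LangChain/FAISS index and wraps it in a Tool for a transformers ReAct agent, but none of the hunks shown add the imports those names need (Document, RecursiveCharacterTextSplitter, HuggingFaceEmbeddings, FAISS, VectorStore, Tool, json). A minimal sketch of the imports the block appears to assume, based on the `pip install langchain sentence-transformers faiss-cpu` line added at the top of the file; the exact module paths depend on the LangChain release, so treat these as assumptions rather than part of the commit:

# Imports the new RAG block appears to rely on (not shown in the hunks above).
# Module paths are a best guess for the classic `langchain` package implied by
# the pip install line; newer releases move these into langchain_community /
# langchain_huggingface.
import json

from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from transformers.agents import Tool

Note also that HuggingFaceEmbeddings is normally constructed with a keyword argument, e.g. `HuggingFaceEmbeddings(model_name="thenlper/gte-small")`; the positional form in the diff may raise a TypeError on recent LangChain releases.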
 
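Inside model_inference, the commit calls `agent.run(user_prompt["text"])` and prints the result, while the streamed reply still comes from the Idefics2 generate call. A hypothetical smoke test for the retrieval layer, assuming the module-level setup above has already executed; the query string is illustrative only:

# Hypothetical smoke test for the retrieval layer added in this commit.
# It reuses the module-level objects defined above (vectordb, retriever_tool, agent).
query = "How to push a dataset to the Hugging Face Hub"

# Direct vector-store lookup, bypassing the agent.
for doc in vectordb.similarity_search(query, k=3):
    print(doc.metadata["source"], "->", doc.page_content[:80])

# The same lookup through the tool interface the agent calls.
print(retriever_tool.forward(query))

# Full ReAct loop: the agent decides when and how to call the retriever.
print(agent.run(query))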