hoduyquocbao commited on
Commit
c6bca05
·
1 Parent(s): 0abfd6d

fix errors

Browse files
Files changed (1) hide show
  1. app.py +170 -170
app.py CHANGED
@@ -23,56 +23,58 @@ from datasets import load_dataset
23
  from peft import LoraConfig, get_peft_model
24
  import time
25
 
26
- # ---------------------------- Configuration ---------------------------- #
27
 
28
  DESCRIPTION = """\
29
- # Llama 3.2 3B Instruct with Advanced Features
30
 
31
- Llama 3.2 3B is the latest version from Meta for open language models.
32
- This demo showcases [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), fine-tuned for instruction-following.
33
- For more details, please see [our blog post](https://huggingface.co/blog/llama32).
34
  """
35
 
36
- MAX_MAX_NEW_TOKENS = 2048 # Maximum tokens to generate
37
- DEFAULT_MAX_NEW_TOKENS = 1024 # Default tokens to generate
38
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000")) # Max input token length
39
 
40
- # Determine device (GPU if available, else CPU)
41
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42
 
43
- model_id = "meta-llama/Llama-3.2-3B-Instruct" # Model ID
44
  tokenizer = AutoTokenizer.from_pretrained(model_id)
45
  model = AutoModelForCausalLM.from_pretrained(
46
  model_id,
47
  device_map="auto",
48
- torch_dtype=torch.float16, # Use float16 for compatibility with fp16=True
49
  )
50
  model.to(device)
51
  model.eval()
52
 
53
- # Initialize sentiment analysis pipeline on GPU if available
54
  sentiment_pipeline = pipeline(
55
  "sentiment-analysis",
56
  model="nlptown/bert-base-multilingual-uncased-sentiment",
57
  device=0 if torch.cuda.is_available() else -1
58
  )
59
 
60
- # ---------------------------- Function Definitions ---------------------------- #
61
 
62
  @lru_cache(maxsize=128)
63
  def extract_text_from_webpage(html_content: str) -> str:
64
- """Extract visible text from HTML content using BeautifulSoup."""
65
  soup = BeautifulSoup(html_content, "html.parser")
 
66
  for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
67
  tag.extract()
 
68
  visible_text = soup.get_text(separator=' ', strip=True)
69
  return visible_text
70
 
71
  def search(query: str) -> List[Dict[str, Any]]:
72
- """Perform a Google search and return results."""
73
  term = query
74
  all_results = []
75
- max_chars_per_page = 8000 # Max characters per page
76
  headers = {
77
  "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
78
  }
@@ -81,15 +83,15 @@ def search(query: str) -> List[Dict[str, Any]]:
81
  resp = session.get(
82
  url="https://www.google.com/search",
83
  headers=headers,
84
- params={"q": term, "num": 4}, # 4 results per page
85
  timeout=5,
86
- verify=False, # Skip SSL verification
87
  )
88
  resp.raise_for_status()
89
  soup = BeautifulSoup(resp.text, "html.parser")
90
- result_blocks = soup.find_all("div", attrs={"class": "g"})
91
  for result in result_blocks:
92
- link_tag = result.find("a", href=True)
93
  if link_tag and 'href' in link_tag.attrs:
94
  link = link_tag["href"]
95
  try:
@@ -102,18 +104,18 @@ def search(query: str) -> List[Dict[str, Any]]:
102
  webpage.raise_for_status()
103
  visible_text = extract_text_from_webpage(webpage.text)
104
  if len(visible_text) > max_chars_per_page:
105
- visible_text = visible_text[:max_chars_per_page]
106
  all_results.append({"link": link, "text": visible_text})
107
  except requests.exceptions.RequestException:
108
- all_results.append({"link": link, "text": "Could not retrieve content."})
109
  except requests.exceptions.RequestException as e:
110
- all_results.append({"link": "N/A", "text": "Could not perform search."})
111
  return all_results
112
 
113
  def summarize_text(text: str, max_length: int = 150) -> str:
114
- """Summarize text using the Llama model."""
115
  conversation = [
116
- {"role": "user", "content": f"Please summarize the following text: {text}"}
117
  ]
118
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
119
  input_ids = input_ids.to(device)
@@ -136,33 +138,33 @@ def summarize_text(text: str, max_length: int = 150) -> str:
136
  return summary
137
 
138
  def analyze_sentiment(text: str) -> str:
139
- """Analyze sentiment of the text using a sentiment analysis model."""
140
  result = sentiment_pipeline(text)
141
  sentiment = result[0]['label']
142
  score = result[0]['score']
143
- return f"🟢 **Sentiment**: {sentiment} (Score: {score:.2f})"
144
 
145
  def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
146
  """
147
- Generate a response using the Llama model in streaming mode.
148
  """
149
- # Build conversation history
150
  conversation = []
151
  for user, assistant in chat_history:
152
  conversation.extend([
153
  {"role": "user", "content": user},
154
  {"role": "assistant", "content": assistant},
155
  ])
156
- conversation.append({"role": "user", "content": prompt}) # Add user's message
157
 
158
- # Prepare input_ids from tokenizer
159
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
160
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
161
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] # Truncate input if too long
162
- gr.Warning(f"Truncated conversation due to exceeding {MAX_INPUT_TOKEN_LENGTH} tokens.")
163
- input_ids = input_ids.to(device)
164
 
165
- # Initialize streamer for real-time text generation
166
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
167
  generate_kwargs = {
168
  "input_ids": input_ids,
@@ -175,10 +177,10 @@ def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_
175
  "num_beams": 1,
176
  "repetition_penalty": repetition_penalty,
177
  }
178
- t = Thread(target=model.generate, kwargs=generate_kwargs) # Create thread for text generation
179
  t.start()
180
 
181
- # Stream generated text
182
  outputs = []
183
  for text in streamer:
184
  outputs.append(text)
@@ -187,9 +189,9 @@ def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_
187
  @lru_cache(maxsize=128)
188
  def process_query(query: str) -> Dict[str, Any]:
189
  """
190
- Determine which function to call based on the user's query.
191
  """
192
- # Define keywords or patterns to identify functions
193
  web_search_keywords = ["search", "find", "lookup", "google"]
194
  general_query_keywords = ["explain", "describe", "tell me about", "what is", "how to"]
195
  summarize_keywords = ["summarize", "summarise", "brief", "short"]
@@ -224,60 +226,60 @@ def process_query(query: str) -> Dict[str, Any]:
224
 
225
  def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
226
  """
227
- Execute the appropriate function based on the function call.
228
  """
229
  function_name = function_call["name"]
230
  arguments = function_call["arguments"]
231
 
232
  if function_name == "web_search":
233
  query = arguments["query"]
234
- yield "🔍 Performing web search..."
235
  web_results = search(query)
236
  if not web_results:
237
- yield "⚠️ No results found."
238
  return
239
- # Summarize search results
240
- web_summary = '\n\n'.join([f"🔗 **Link**: {res['link']}\n📝 **Description**: {res['text']}" for res in web_results if res["text"] != "Could not retrieve content."])
241
  if not web_summary:
242
- web_summary = "⚠️ Could not retrieve content from search results."
243
 
244
- # Return search results to user
245
- yield "📄 **Search Results:**\n" + web_summary
246
 
247
  elif function_name == "summarize_query":
248
- # When user requests summarization, perform search and then summarize
249
  query = arguments["prompt"]
250
- yield "🔍 Performing search for summarization..."
251
  web_results = search(query)
252
  if not web_results:
253
- yield "⚠️ No results found to summarize."
254
  return
255
- # Combine content from search results for summarization
256
- combined_text = ' '.join([res['text'] for res in web_results if res['text'] != "Could not retrieve content."])
257
  if not combined_text:
258
- yield "⚠️ No content available to summarize."
259
  return
260
- # Summarize the combined content
261
- yield "📝 Summarizing information..."
262
  summary = summarize_text(combined_text)
263
- yield "📄 **Summary:**\n" + summary
264
 
265
  elif function_name == "sentiment_analysis":
266
  prompt_text = arguments["prompt"]
267
- yield "📊 Analyzing sentiment..."
268
  sentiment = analyze_sentiment(prompt_text)
269
  yield sentiment
270
 
271
  elif function_name == "train_model":
272
  prompt_text = arguments["prompt"]
273
- yield "📊 Training the model..."
274
  training_result = run_training()
275
  yield training_result
276
 
277
  elif function_name in ["general_query", "hard_query"]:
278
  prompt_text = arguments["prompt"]
279
- yield "🤖 Generating response..."
280
- # Generate response using the Llama model
281
  response_generator = generate_response(
282
  prompt=prompt_text,
283
  chat_history=chat_history,
@@ -291,24 +293,24 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
291
  yield response
292
 
293
  else:
294
- yield "⚠️ Unrecognized function call."
295
 
296
- # ---------------------------- Training Setup ---------------------------- #
297
 
298
- # Checkpoint directory
299
  CHECKPOINT_DIR = "./checkpoints"
300
  if not os.path.exists(CHECKPOINT_DIR):
301
  os.makedirs(CHECKPOINT_DIR)
302
 
303
- # Load Dataset (CPU)
304
  dataset = load_dataset('vntc/wiki-mini-corpus')
305
 
306
- # Split Dataset into train and validation (CPU)
307
  split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
308
  train_dataset = split_dataset['train']
309
  validation_dataset = split_dataset['test']
310
 
311
- # Text Preprocessing (CPU)
312
  def preprocess_function(examples):
313
  passages = [passage.lower().strip() for passage in examples['passage']]
314
  return {'passage': passages}
@@ -316,7 +318,7 @@ def preprocess_function(examples):
316
  processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
317
  processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
318
 
319
- # Ensure tokenizer has pad_token
320
  if tokenizer.pad_token is None:
321
  tokenizer.pad_token = tokenizer.eos_token
322
 
@@ -332,7 +334,7 @@ def tokenize_function(examples):
332
  tokenized_train = processed_train.map(tokenize_function, batched=True)
333
  tokenized_validation = processed_validation.map(tokenize_function, batched=True)
334
 
335
- # Add 'labels' field (CPU)
336
  def add_labels(examples):
337
  examples['labels'] = examples['input_ids'].copy()
338
  return examples
@@ -340,31 +342,29 @@ def add_labels(examples):
340
  tokenized_train = tokenized_train.map(add_labels, batched=True)
341
  tokenized_validation = tokenized_validation.map(add_labels, batched=True)
342
 
343
- # Remove unnecessary columns (CPU)
344
  tokenized_train = tokenized_train.remove_columns(['passage'])
345
  tokenized_validation = tokenized_validation.remove_columns(['passage'])
346
 
347
- # Set format for PyTorch (CPU)
348
  tokenized_train.set_format('torch')
349
  tokenized_validation.set_format('torch')
350
 
351
- # Create DatasetDict (CPU)
352
  final_dataset = {
353
  'train': tokenized_train,
354
  'validation': tokenized_validation
355
  }
356
 
357
- # Define TrainerCallback to Save Checkpoints Faster
358
  class SaveCheckpointCallback(TrainerCallback):
359
- def on_step_end(self, args, state, control, **kwargs):
360
- if state.global_step % args.save_steps == 0 and state.global_step != 0:
361
- checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
362
- print(f"Saving checkpoint at: {checkpoint_path}")
363
- trainer = kwargs['trainer'] # Access trainer from kwargs
364
- trainer.save_model(checkpoint_path)
365
- return control # Return current control object
366
-
367
- # Load pretrained model with LoRA
368
  pretrained = AutoModelForCausalLM.from_pretrained(
369
  model_id,
370
  device_map="auto",
@@ -374,30 +374,30 @@ pretrained = AutoModelForCausalLM.from_pretrained(
374
 
375
  data_collator = DataCollatorForLanguageModeling(
376
  tokenizer=tokenizer,
377
- mlm=False, # Causal LM
378
  pad_to_multiple_of=8
379
  )
380
 
381
  def get_step_done() -> int:
382
  """
383
- Get the number of training steps completed from the latest checkpoint.
384
 
385
  Returns:
386
- int: Number of steps completed. Returns 0 if no checkpoint is found.
387
  """
388
  checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint-')]
389
  if not checkpoints:
390
  return 0
391
  try:
392
- # Find the latest checkpoint based on step number
393
  latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('-')[1]))
394
  step_done = int(latest_checkpoint.split('-')[1])
395
  return step_done
396
  except (IndexError, ValueError) as e:
397
- print(f"Error parsing checkpoint name: {e}")
398
  return 0
399
 
400
- # Load and Configure LoRA (GPU)
401
  lora_config = LoraConfig(
402
  r=8,
403
  lora_alpha=32,
@@ -412,34 +412,34 @@ print(pretrained_model)
412
  @spaces.GPU(duration=30, queue=False)
413
  def run_training() -> str:
414
  """
415
- Train the model using GPU with time constraints.
416
 
417
  Returns:
418
- str: Training result message.
419
  """
420
 
421
- # TrainingArguments Configuration (GPU)
422
  training_args = TrainingArguments(
423
  output_dir=CHECKPOINT_DIR,
424
  per_device_train_batch_size=4,
425
  per_device_eval_batch_size=4,
426
  gradient_accumulation_steps=8,
427
  num_train_epochs=3,
428
- max_steps=300, # Total training steps
429
  learning_rate=3e-4,
430
  weight_decay=0.01,
431
- logging_steps=1, # Log every step
432
- eval_strategy="steps", # Evaluate every few steps
433
- eval_steps=5, # Evaluate every 5 steps
434
- save_strategy="steps", # Save checkpoint every few steps
435
- save_steps=5, # Save every 5 steps
436
- save_total_limit=5, # Limit number of saved checkpoints
437
- fp16=True,
438
  report_to="none",
439
  load_best_model_at_end=True,
440
  )
441
 
442
- # Initialize Trainer (GPU)
443
  trainer = Trainer(
444
  model=pretrained_model,
445
  args=training_args,
@@ -447,50 +447,50 @@ def run_training() -> str:
447
  eval_dataset=final_dataset['validation'],
448
  tokenizer=tokenizer,
449
  data_collator=data_collator,
450
- callbacks=[SaveCheckpointCallback()], # Add callback
451
  )
452
 
453
- # Check for existing checkpoint
454
  steps_done = get_step_done()
455
  if steps_done > 0:
456
- # Determine the latest checkpoint based on step number
457
  latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
458
  if os.path.exists(latest_checkpoint):
459
- print(f"Resuming training from checkpoint: {latest_checkpoint}")
460
  trainer.train(resume_from_checkpoint=latest_checkpoint)
461
  else:
462
- print(f"Checkpoint {latest_checkpoint} does not exist. Starting training from scratch.")
463
  trainer.train()
464
  else:
465
  trainer.train()
466
 
467
- # Save checkpoint after training
468
  trainer.save_model(CHECKPOINT_DIR)
469
- return "Training completed or resumed from checkpoint."
470
 
471
- # Automatic Function to Call Training Repeatedly
472
  @spaces.GPU(duration=30, queue=False)
473
  def continuous_training(total_steps=300, steps_per_call=50):
474
  """
475
- Automatically call `run_training` to complete the training process.
476
 
477
  Args:
478
- total_steps (int): Desired total training steps.
479
- steps_per_call (int): Training steps per function call.
480
  """
481
  steps_done = get_step_done()
482
  while steps_done < total_steps:
483
  remaining_steps = total_steps - steps_done
484
  current_steps = min(steps_per_call, remaining_steps)
485
- print(f"Starting training for {current_steps} steps.")
486
 
487
- # Update TrainingArguments for current_steps
488
  training_args = TrainingArguments(
489
  output_dir=CHECKPOINT_DIR,
490
  per_device_train_batch_size=4,
491
  per_device_eval_batch_size=4,
492
  gradient_accumulation_steps=8,
493
- num_train_epochs=1, # Train for one epoch
494
  max_steps=current_steps,
495
  learning_rate=3e-4,
496
  weight_decay=0.01,
@@ -505,7 +505,7 @@ def continuous_training(total_steps=300, steps_per_call=50):
505
  load_best_model_at_end=True,
506
  )
507
 
508
- # Initialize Trainer with updated TrainingArguments
509
  trainer = Trainer(
510
  model=pretrained_model,
511
  args=training_args,
@@ -516,30 +516,30 @@ def continuous_training(total_steps=300, steps_per_call=50):
516
  callbacks=[SaveCheckpointCallback()],
517
  )
518
 
519
- # Resume training from latest checkpoint
520
  if steps_done > 0:
521
  latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
522
  if os.path.exists(latest_checkpoint):
523
- print(f"Resuming training from checkpoint: {latest_checkpoint}")
524
  trainer.train(resume_from_checkpoint=latest_checkpoint)
525
  else:
526
- print(f"Checkpoint {latest_checkpoint} does not exist. Starting training from scratch.")
527
  trainer.train()
528
  else:
529
  trainer.train()
530
 
531
  steps_done = get_step_done()
532
- print(f"Trained {steps_done} / {total_steps} steps.")
533
 
534
- # Check if desired steps are achieved
535
  if steps_done >= total_steps:
536
- print("Completed the entire training process.")
537
  break
538
 
539
- # Wait before the next training call
540
- time.sleep(2) # Adjust wait time as needed
541
 
542
- # ---------------------------- Gradio Interface ---------------------------- #
543
 
544
  @spaces.GPU(duration=30, queue=False)
545
  def generate(
@@ -552,29 +552,29 @@ def generate(
552
  repetition_penalty: float = 1.2,
553
  ) -> Iterator[str]:
554
  """
555
- Main function to handle user input and generate responses.
556
  """
557
- # Notify about query analysis
558
- yield "🔍 Analyzing your query..."
559
 
560
- # Determine which function to call based on user's message
561
  function_call = process_query(message)
562
 
563
- # Notify about the selected function
564
  if function_call["name"] == "web_search":
565
- yield "🛠️ Selected function: Web Search."
566
  elif function_call["name"] == "summarize_query":
567
- yield "🛠️ Selected function: Text Summarization."
568
  elif function_call["name"] == "sentiment_analysis":
569
- yield "🛠️ Selected function: Sentiment Analysis."
570
  elif function_call["name"] in ["general_query", "hard_query"]:
571
- yield "🛠️ Selected function: Answering Questions."
572
  elif function_call["name"] == "train_model":
573
- yield "🛠️ Selected function: Model Training."
574
  else:
575
- yield "⚠️ Could not determine an appropriate function."
576
 
577
- # Execute the function call and generate responses
578
  response_iterator = handle_functions(
579
  function_call=function_call,
580
  prompt=message,
@@ -589,40 +589,40 @@ def generate(
589
  for response in response_iterator:
590
  yield response
591
 
592
- # Define examples to guide users
593
  EXAMPLES = [
594
- ["Hello! How are you?"],
595
- ["Can you briefly explain the Python programming language?"],
596
- ["Explain the plot of Cinderella in one sentence."],
597
- ["How many hours does a man need to eat a helicopter?"],
598
- ["Write a 100-word article on 'Benefits of Open Source in AI Research'"],
599
- ["Search and provide me with the latest news on renewable energy."],
600
- ["Find information about the Great Barrier Reef coral reefs."],
601
- ["Summarize information about artificial intelligence."],
602
- ["Analyze the sentiment of the following text: I am very happy to meet you today!"],
603
- ["Train the model!"],
604
  ]
605
 
606
- # Configure Gradio chat interface with enhanced UI
607
  chat_interface = gr.ChatInterface(
608
- fn=generate, # Function called on user interaction
609
  additional_inputs=[
610
  gr.Slider(
611
- label="Max New Tokens",
612
  minimum=1,
613
  maximum=MAX_MAX_NEW_TOKENS,
614
  step=1,
615
  value=DEFAULT_MAX_NEW_TOKENS,
616
  ),
617
  gr.Slider(
618
- label="Temperature",
619
  minimum=0.1,
620
  maximum=4.0,
621
  step=0.1,
622
  value=0.6,
623
  ),
624
  gr.Slider(
625
- label="Top-p (Nucleus Sampling)",
626
  minimum=0.05,
627
  maximum=1.0,
628
  step=0.05,
@@ -636,45 +636,45 @@ chat_interface = gr.ChatInterface(
636
  value=50,
637
  ),
638
  gr.Slider(
639
- label="Repetition Penalty",
640
  minimum=1.0,
641
  maximum=2.0,
642
  step=0.05,
643
  value=1.2,
644
  ),
645
  ],
646
- stop_btn=None, # No stop button
647
- examples=EXAMPLES, # Display examples to users
648
- cache_examples=False, # Do not cache examples
649
  title="🤖 OpenGPT-4o Chatbot",
650
- description="A powerful AI assistant using the local Llama-3.2 model with web search, text summarization, and sentiment analysis functionalities.",
651
- theme="default", # Customize theme as needed
652
  )
653
 
654
- # Create the main Gradio interface with custom CSS
655
  with gr.Blocks(css="""
656
  .gradio-container {
657
- background-color: #f0f2f5; /* Light background color */
658
  }
659
  .gradio-container h1 {
660
- color: #4a90e2; /* Blue color for title */
661
  }
662
  .gradio-container .gr-button {
663
- background-color: #4a90e2; /* Blue color for buttons */
664
- color: white; /* White text on buttons */
665
  }
666
  .gradio-container .gr-slider__label {
667
- color: #333333; /* Dark text for slider labels */
668
  }
669
  .gradio-container .gr-chatbot {
670
- border: 2px solid #4a90e2; /* Blue border for chatbot */
671
- border-radius: 10px; /* Rounded corners for chatbot */
672
- padding: 10px; /* Inner padding for chatbot */
673
- background-color: #ffffff; /* White background for chatbot */
674
  }
675
  """, fill_height=True) as demo:
676
- gr.Markdown(DESCRIPTION) # Display description
677
- chat_interface.render() # Render chat interface
678
 
679
  if __name__ == "__main__":
680
- demo.queue(max_size=30).launch() # Launch Gradio app with a queue size of 30
 
23
  from peft import LoraConfig, get_peft_model
24
  import time
25
 
26
+ # ---------------------------- Cấu Hình ---------------------------- #
27
 
28
  DESCRIPTION = """\
29
+ # Llama 3.2 3B Instruct với Chức Năng Nâng Cao
30
 
31
+ Llama 3.2 3B phiên bản mới nhất của Meta về các hình ngôn ngữ mở.
32
+ Demo này giới thiệu [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), được tinh chỉnh để theo dõi hướng dẫn.
33
+ Để biết thêm chi tiết, vui lòng xem [bài viết của chúng tôi](https://huggingface.co/blog/llama32).
34
  """
35
 
36
+ MAX_MAX_NEW_TOKENS = 2048 # Số token tối đa có thể tạo ra
37
+ DEFAULT_MAX_NEW_TOKENS = 1024 # Số token tạo ra mặc định
38
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000")) # Độ dài token tối đa cho đầu vào
39
 
40
+ # Xác định thiết bị sử dụng (GPU nếu có, ngược lại CPU)
41
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42
 
43
+ model_id = "meta-llama/Llama-3.2-3B-Instruct" # ID mô hình
44
  tokenizer = AutoTokenizer.from_pretrained(model_id)
45
  model = AutoModelForCausalLM.from_pretrained(
46
  model_id,
47
  device_map="auto",
48
+ torch_dtype=torch.float16, # Sử dụng float16 để tương thích với fp16=True
49
  )
50
  model.to(device)
51
  model.eval()
52
 
53
+ # Khởi tạo pipeline phân tích tâm lý trên GPU nếu
54
  sentiment_pipeline = pipeline(
55
  "sentiment-analysis",
56
  model="nlptown/bert-base-multilingual-uncased-sentiment",
57
  device=0 if torch.cuda.is_available() else -1
58
  )
59
 
60
+ # ---------------------------- Định Nghĩa Hàm ---------------------------- #
61
 
62
  @lru_cache(maxsize=128)
63
  def extract_text_from_webpage(html_content: str) -> str:
64
+ """Trích xuất văn bản hiển thị từ nội dung HTML sử dụng BeautifulSoup."""
65
  soup = BeautifulSoup(html_content, "html.parser")
66
+ # Loại bỏ các thẻ không hiển thị như script, style, header, footer, nav, form, svg
67
  for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
68
  tag.extract()
69
+ # Trích xuất văn bản hiển thị, tách bằng dấu cách và loại bỏ khoảng trắng thừa
70
  visible_text = soup.get_text(separator=' ', strip=True)
71
  return visible_text
72
 
73
  def search(query: str) -> List[Dict[str, Any]]:
74
+ """Thực hiện tìm kiếm trên Google trả về kết quả."""
75
  term = query
76
  all_results = []
77
+ max_chars_per_page = 8000 # Số tự tối đa mỗi trang
78
  headers = {
79
  "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
80
  }
 
83
  resp = session.get(
84
  url="https://www.google.com/search",
85
  headers=headers,
86
+ params={"q": term, "num": 4}, # Tìm kiếm với 4 kết quả mỗi trang
87
  timeout=5,
88
+ verify=False, # Bỏ qua xác minh SSL
89
  )
90
  resp.raise_for_status()
91
  soup = BeautifulSoup(resp.text, "html.parser")
92
+ result_blocks = soup.find_all("div", attrs={"class": "g"}) # Tìm tất cả các khối kết quả
93
  for result in result_blocks:
94
+ link_tag = result.find("a", href=True) # Tìm thẻ liên kết
95
  if link_tag and 'href' in link_tag.attrs:
96
  link = link_tag["href"]
97
  try:
 
104
  webpage.raise_for_status()
105
  visible_text = extract_text_from_webpage(webpage.text)
106
  if len(visible_text) > max_chars_per_page:
107
+ visible_text = visible_text[:max_chars_per_page] # Cắt văn bản nếu quá dài
108
  all_results.append({"link": link, "text": visible_text})
109
  except requests.exceptions.RequestException:
110
+ all_results.append({"link": link, "text": "Không thể lấy nội dung."})
111
  except requests.exceptions.RequestException as e:
112
+ all_results.append({"link": "N/A", "text": "Không thể thực hiện tìm kiếm."})
113
  return all_results
114
 
115
  def summarize_text(text: str, max_length: int = 150) -> str:
116
+ """Tóm tắt văn bản sử dụng mô hình Llama."""
117
  conversation = [
118
+ {"role": "user", "content": f"Hãy tóm tắt đoạn văn sau: {text}"}
119
  ]
120
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
121
  input_ids = input_ids.to(device)
 
138
  return summary
139
 
140
  def analyze_sentiment(text: str) -> str:
141
+ """Phân tích tâm của văn bản sử dụng mô hình."""
142
  result = sentiment_pipeline(text)
143
  sentiment = result[0]['label']
144
  score = result[0]['score']
145
+ return f"🟢 **Tâm lý**: {sentiment} (Điểm: {score:.2f})"
146
 
147
  def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
148
  """
149
+ Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming.
150
  """
151
+ # Xây dựng lịch sử cuộc trò chuyện
152
  conversation = []
153
  for user, assistant in chat_history:
154
  conversation.extend([
155
  {"role": "user", "content": user},
156
  {"role": "assistant", "content": assistant},
157
  ])
158
+ conversation.append({"role": "user", "content": prompt}) # Thêm tin nhắn của người dùng
159
 
160
+ # Chuẩn bị input_ids từ tokenizer
161
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
162
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
163
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] # Cắt input nếu quá dài
164
+ gr.Warning(f"Đã cắt bỏ phần cuộc trò chuyện vì vượt quá {MAX_INPUT_TOKEN_LENGTH} token.")
165
+ input_ids = input_ids.to(device) # Di chuyển input tới thiết bị
166
 
167
+ # Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực
168
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
169
  generate_kwargs = {
170
  "input_ids": input_ids,
 
177
  "num_beams": 1,
178
  "repetition_penalty": repetition_penalty,
179
  }
180
+ t = Thread(target=model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản
181
  t.start()
182
 
183
+ # Stream văn bản được tạo ra
184
  outputs = []
185
  for text in streamer:
186
  outputs.append(text)
 
189
  @lru_cache(maxsize=128)
190
  def process_query(query: str) -> Dict[str, Any]:
191
  """
192
+ Xác định hàm nào sẽ được gọi dựa trên truy vấn của người dùng.
193
  """
194
+ # Định nghĩa các từ khóa hoặc mẫu để xác định hàm
195
  web_search_keywords = ["search", "find", "lookup", "google"]
196
  general_query_keywords = ["explain", "describe", "tell me about", "what is", "how to"]
197
  summarize_keywords = ["summarize", "summarise", "brief", "short"]
 
226
 
227
  def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
228
  """
229
+ Thực thi hàm phù hợp dựa trên lời gọi hàm.
230
  """
231
  function_name = function_call["name"]
232
  arguments = function_call["arguments"]
233
 
234
  if function_name == "web_search":
235
  query = arguments["query"]
236
+ yield "🔍 Đang thực hiện tìm kiếm trên web..."
237
  web_results = search(query)
238
  if not web_results:
239
+ yield "⚠️ Không tìm thấy kết quả."
240
  return
241
+ # Tóm tắt kết quả tìm kiếm
242
+ web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 ** tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
243
  if not web_summary:
244
+ web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
245
 
246
+ # Trả về kết quả tìm kiếm cho người dùng
247
+ yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
248
 
249
  elif function_name == "summarize_query":
250
+ # Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
251
  query = arguments["prompt"]
252
+ yield "🔍 Đang thực hiện tìm kiếm để tóm tắt..."
253
  web_results = search(query)
254
  if not web_results:
255
+ yield "⚠️ Không tìm thấy kết quả để tóm tắt."
256
  return
257
+ # Lấy nội dung từ kết quả tìm kiếm để tóm tắt
258
+ combined_text = ' '.join([res['text'] for res in web_results if res['text'] != "Không thể lấy nội dung."])
259
  if not combined_text:
260
+ yield "⚠️ Không nội dung để tóm tắt."
261
  return
262
+ # Tóm tắt nội dung đã lấy
263
+ yield "📝 Đang tóm tắt thông tin..."
264
  summary = summarize_text(combined_text)
265
+ yield "📄 **Tóm tắt:**\n" + summary
266
 
267
  elif function_name == "sentiment_analysis":
268
  prompt_text = arguments["prompt"]
269
+ yield "📊 Đang phân tích tâm lý..."
270
  sentiment = analyze_sentiment(prompt_text)
271
  yield sentiment
272
 
273
  elif function_name == "train_model":
274
  prompt_text = arguments["prompt"]
275
+ yield "📊 Đang huấn luyện mô hình..."
276
  training_result = run_training()
277
  yield training_result
278
 
279
  elif function_name in ["general_query", "hard_query"]:
280
  prompt_text = arguments["prompt"]
281
+ yield "🤖 Đang tạo phản hồi..."
282
+ # Tạo phản hồi sử dụng mô hình Llama
283
  response_generator = generate_response(
284
  prompt=prompt_text,
285
  chat_history=chat_history,
 
293
  yield response
294
 
295
  else:
296
+ yield "⚠️ Lời gọi hàm không được nhận dạng."
297
 
298
+ # ---------------------------- Huấn luyện ---------------------------- #
299
 
300
+ # Đường dẫn lưu checkpoint
301
  CHECKPOINT_DIR = "./checkpoints"
302
  if not os.path.exists(CHECKPOINT_DIR):
303
  os.makedirs(CHECKPOINT_DIR)
304
 
305
+ # Tải Dataset (CPU)
306
  dataset = load_dataset('vntc/wiki-mini-corpus')
307
 
308
+ # Chia Dataset thành train validation (CPU)
309
  split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
310
  train_dataset = split_dataset['train']
311
  validation_dataset = split_dataset['test']
312
 
313
+ # Tiền Xử Lý Văn Bản (CPU)
314
  def preprocess_function(examples):
315
  passages = [passage.lower().strip() for passage in examples['passage']]
316
  return {'passage': passages}
 
318
  processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
319
  processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
320
 
321
+ # Đảm bảo tokenizer pad_token
322
  if tokenizer.pad_token is None:
323
  tokenizer.pad_token = tokenizer.eos_token
324
 
 
334
  tokenized_train = processed_train.map(tokenize_function, batched=True)
335
  tokenized_validation = processed_validation.map(tokenize_function, batched=True)
336
 
337
+ # Thêm trường 'labels' (CPU)
338
  def add_labels(examples):
339
  examples['labels'] = examples['input_ids'].copy()
340
  return examples
 
342
  tokenized_train = tokenized_train.map(add_labels, batched=True)
343
  tokenized_validation = tokenized_validation.map(add_labels, batched=True)
344
 
345
+ # Loại bỏ các cột không cần thiết (CPU)
346
  tokenized_train = tokenized_train.remove_columns(['passage'])
347
  tokenized_validation = tokenized_validation.remove_columns(['passage'])
348
 
349
+ # Định dạng dữ liệu cho PyTorch (CPU)
350
  tokenized_train.set_format('torch')
351
  tokenized_validation.set_format('torch')
352
 
353
+ # Tạo DatasetDict (CPU)
354
  final_dataset = {
355
  'train': tokenized_train,
356
  'validation': tokenized_validation
357
  }
358
 
359
+ # Định Nghĩa TrainerCallback để Lưu Checkpoint Nhanh Hơn
360
  class SaveCheckpointCallback(TrainerCallback):
361
+ def on_save(self, args, state, control, **kwargs):
362
+ checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
363
+ print(f"Lưu checkpoint tại: {checkpoint_path}")
364
+ kwargs['trainer'].save_model(checkpoint_path)
365
+ return control # Trả về đối tượng control hiện tại
366
+
367
+ # Tải hình đã được pretrained
 
 
368
  pretrained = AutoModelForCausalLM.from_pretrained(
369
  model_id,
370
  device_map="auto",
 
374
 
375
  data_collator = DataCollatorForLanguageModeling(
376
  tokenizer=tokenizer,
377
+ mlm=False, # Vì bạn đang thực hiện Causal LM
378
  pad_to_multiple_of=8
379
  )
380
 
381
  def get_step_done() -> int:
382
  """
383
+ Lấy số bước huấn luyện đã hoàn thành từ checkpoint mới nhất trong thư mục lưu trữ.
384
 
385
  Returns:
386
+ int: Số bước đã hoàn thành. Trả về 0 nếu không tìm thấy checkpoint.
387
  """
388
  checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint-')]
389
  if not checkpoints:
390
  return 0
391
  try:
392
+ # Tìm checkpoint mới nhất dựa trên số bước
393
  latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('-')[1]))
394
  step_done = int(latest_checkpoint.split('-')[1])
395
  return step_done
396
  except (IndexError, ValueError) as e:
397
+ print(f"Lỗi khi phân tích tên checkpoint: {e}")
398
  return 0
399
 
400
+ # Tải Cấu Hình Mô Hình với LoRA (GPU)
401
  lora_config = LoraConfig(
402
  r=8,
403
  lora_alpha=32,
 
412
  @spaces.GPU(duration=30, queue=False)
413
  def run_training() -> str:
414
  """
415
+ Hàm huấn luyện hình sử dụng GPU với thời gian hạn chế.
416
 
417
  Returns:
418
+ str: Thông báo kết quả huấn luyện.
419
  """
420
 
421
+ # Cấu Hình TrainingArguments (GPU)
422
  training_args = TrainingArguments(
423
  output_dir=CHECKPOINT_DIR,
424
  per_device_train_batch_size=4,
425
  per_device_eval_batch_size=4,
426
  gradient_accumulation_steps=8,
427
  num_train_epochs=3,
428
+ max_steps=300, # Đặt tổng số bước huấn luyện
429
  learning_rate=3e-4,
430
  weight_decay=0.01,
431
+ logging_steps=1, # Ghi log sau mỗi bước
432
+ eval_strategy="steps", # Đánh giá sau mỗi vài bước
433
+ eval_steps=5, # Đánh giá sau mỗi 5 bước
434
+ save_strategy="steps", # Lưu checkpoint sau mỗi vài bước
435
+ save_steps=5, # Lưu checkpoint sau mỗi 5 bước
436
+ save_total_limit=5, # Giới hạn số lượng checkpoint lưu trữ
437
+ fp16=True, # Kích hoạt huấn luyện hỗn hợp độ chính xác
438
  report_to="none",
439
  load_best_model_at_end=True,
440
  )
441
 
442
+ # Tạo Trainer (GPU)
443
  trainer = Trainer(
444
  model=pretrained_model,
445
  args=training_args,
 
447
  eval_dataset=final_dataset['validation'],
448
  tokenizer=tokenizer,
449
  data_collator=data_collator,
450
+ callbacks=[SaveCheckpointCallback()], # Thêm callback
451
  )
452
 
453
+ # Kiểm tra nếu checkpoint
454
  steps_done = get_step_done()
455
  if steps_done > 0:
456
+ # Xác định checkpoint mới nhất dựa trên số bước
457
  latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
458
  if os.path.exists(latest_checkpoint):
459
+ print(f"Đang tiếp tục huấn luyện từ checkpoint: {latest_checkpoint}")
460
  trainer.train(resume_from_checkpoint=latest_checkpoint)
461
  else:
462
+ print(f"Checkpoint {latest_checkpoint} không tồn tại. Bắt đầu huấn luyện từ đầu.")
463
  trainer.train()
464
  else:
465
  trainer.train()
466
 
467
+ # Lưu checkpoint sau khi huấn luyện
468
  trainer.save_model(CHECKPOINT_DIR)
469
+ return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
470
 
471
+ # Hàm Tự Động Hóa Việc Gọi Lặp Lại Hàm Huấn Luyện
472
  @spaces.GPU(duration=30, queue=False)
473
  def continuous_training(total_steps=300, steps_per_call=50):
474
  """
475
+ Hàm tự động gọi lại `run_training` để hoàn thành quá trình huấn luyện.
476
 
477
  Args:
478
+ total_steps (int): Tổng số bước huấn luyện mong muốn.
479
+ steps_per_call (int): Số bước huấn luyện mỗi lần gọi hàm.
480
  """
481
  steps_done = get_step_done()
482
  while steps_done < total_steps:
483
  remaining_steps = total_steps - steps_done
484
  current_steps = min(steps_per_call, remaining_steps)
485
+ print(f"Bắt đầu huấn luyện cho {current_steps} bước.")
486
 
487
+ # Cập nhật TrainingArguments để huấn luyện cho current_steps bước
488
  training_args = TrainingArguments(
489
  output_dir=CHECKPOINT_DIR,
490
  per_device_train_batch_size=4,
491
  per_device_eval_batch_size=4,
492
  gradient_accumulation_steps=8,
493
+ num_train_epochs=1, # Huấn luyện trong một epoch
494
  max_steps=current_steps,
495
  learning_rate=3e-4,
496
  weight_decay=0.01,
 
505
  load_best_model_at_end=True,
506
  )
507
 
508
+ # Tạo Trainer với TrainingArguments mới
509
  trainer = Trainer(
510
  model=pretrained_model,
511
  args=training_args,
 
516
  callbacks=[SaveCheckpointCallback()],
517
  )
518
 
519
+ # Tiếp tục huấn luyện từ checkpoint hiện tại
520
  if steps_done > 0:
521
  latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
522
  if os.path.exists(latest_checkpoint):
523
+ print(f"Đang tiếp tục huấn luyện từ checkpoint: {latest_checkpoint}")
524
  trainer.train(resume_from_checkpoint=latest_checkpoint)
525
  else:
526
+ print(f"Checkpoint {latest_checkpoint} không tồn tại. Bắt đầu huấn luyện từ đầu.")
527
  trainer.train()
528
  else:
529
  trainer.train()
530
 
531
  steps_done = get_step_done()
532
+ print(f"Đã huấn luyện {steps_done} / {total_steps} bước.")
533
 
534
+ # Kiểm tra nếu đã đạt số bước mong muốn
535
  if steps_done >= total_steps:
536
+ print("Đã hoàn thành toàn bộ quá trình huấn luyện.")
537
  break
538
 
539
+ # Chờ một khoảng thời gian trước khi gọi lại (tùy thuộc vào yêu cầu của hệ thống)
540
+ time.sleep(2) # Thời gian chờ thể điều chỉnh
541
 
542
+ # ---------------------------- Giao Diện Gradio ---------------------------- #
543
 
544
  @spaces.GPU(duration=30, queue=False)
545
  def generate(
 
552
  repetition_penalty: float = 1.2,
553
  ) -> Iterator[str]:
554
  """
555
+ Hàm chính để xử đầu vào của người dùng và tạo phản hồi.
556
  """
557
+ # Thông báo về việc phân tích đầu vào
558
+ yield "🔍 Đang phân tích truy vấn của bạn..."
559
 
560
+ # Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
561
  function_call = process_query(message)
562
 
563
+ # Thông báo về hàm được chọn
564
  if function_call["name"] == "web_search":
565
+ yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
566
  elif function_call["name"] == "summarize_query":
567
+ yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
568
  elif function_call["name"] == "sentiment_analysis":
569
+ yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
570
  elif function_call["name"] in ["general_query", "hard_query"]:
571
+ yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
572
  elif function_call["name"] == "train_model":
573
+ yield "🛠️ Đã chọn chức năng: Huấn luyện mô hình."
574
  else:
575
+ yield "⚠️ Không thể xác định chức năng phù hợp."
576
 
577
+ # Xử lời gọi hàm sinh phản hồi tương ứng
578
  response_iterator = handle_functions(
579
  function_call=function_call,
580
  prompt=message,
 
589
  for response in response_iterator:
590
  yield response
591
 
592
+ # Định nghĩa các dụ để hướng dẫn người dùng
593
  EXAMPLES = [
594
+ ["Xin chào! Bạn khỏe không?"],
595
+ ["Bạn thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"],
596
+ ["Giải thích cốt truyện của Lọ Lem trong một câu."],
597
+ ["Một người đàn ông cần bao nhiêu giờ để ăn một chiếc máy bay trực thăng?"],
598
+ ["Viết một bài báo 100 từ về 'Lợi ích của nguồn mở trong nghiên cứu AI'"],
599
+ ["Tìm cung cấp cho tôi tin tức mới nhất về năng lượng tái tạo."],
600
+ ["Tìm thông tin về Rạn san hô Great Barrier Reef."],
601
+ ["Tóm tắt nội dung về trí tuệ nhân tạo."],
602
+ ["Phân tích tâm của đoạn văn sau: Tôi rất vui khi được gặp bạn hôm nay!"],
603
+ ["Huấn luyện mô hình!"],
604
  ]
605
 
606
+ # Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
607
  chat_interface = gr.ChatInterface(
608
+ fn=generate, # Hàm được gọi khi có tương tác từ người dùng
609
  additional_inputs=[
610
  gr.Slider(
611
+ label="Số token mới tối đa",
612
  minimum=1,
613
  maximum=MAX_MAX_NEW_TOKENS,
614
  step=1,
615
  value=DEFAULT_MAX_NEW_TOKENS,
616
  ),
617
  gr.Slider(
618
+ label="Nhiệt độ",
619
  minimum=0.1,
620
  maximum=4.0,
621
  step=0.1,
622
  value=0.6,
623
  ),
624
  gr.Slider(
625
+ label="Top-p (nucleus sampling)",
626
  minimum=0.05,
627
  maximum=1.0,
628
  step=0.05,
 
636
  value=50,
637
  ),
638
  gr.Slider(
639
+ label="Hình phạt sự lặp lại",
640
  minimum=1.0,
641
  maximum=2.0,
642
  step=0.05,
643
  value=1.2,
644
  ),
645
  ],
646
+ stop_btn=None, # Không nút dừng
647
+ examples=EXAMPLES, # Các dụ được hiển thị cho người dùng
648
+ cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ
649
  title="🤖 OpenGPT-4o Chatbot",
650
+ description="Một trợ AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản và phân tích tâm lý.",
651
+ theme="default", # thể thay đổi theme để giao diện đẹp hơn
652
  )
653
 
654
+ # Tạo giao diện chính của Gradio với CSS tùy chỉnh
655
  with gr.Blocks(css="""
656
  .gradio-container {
657
+ background-color: #f0f2f5; /* Màu nền nhẹ nhàng */
658
  }
659
  .gradio-container h1 {
660
+ color: #4a90e2; /* Màu xanh dương cho tiêu đề */
661
  }
662
  .gradio-container .gr-button {
663
+ background-color: #4a90e2; /* Màu xanh dương cho nút */
664
+ color: white; /* Màu chữ trắng trên nút */
665
  }
666
  .gradio-container .gr-slider__label {
667
+ color: #333333; /* Màu chữ đen cho nhãn slider */
668
  }
669
  .gradio-container .gr-chatbot {
670
+ border: 2px solid #4a90e2; /* Viền xanh dương cho chatbot */
671
+ border-radius: 10px; /* Bo góc viền chatbot */
672
+ padding: 10px; /* Khoảng cách bên trong chatbot */
673
+ background-color: #ffffff; /* Màu nền trắng cho chatbot */
674
  }
675
  """, fill_height=True) as demo:
676
+ gr.Markdown(DESCRIPTION) # Hiển thị mô tả
677
+ chat_interface.render() # Hiển thị giao diện trò chuyện
678
 
679
  if __name__ == "__main__":
680
+ demo.queue(max_size=30).launch() # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là 30