hoduyquocbao commited on
Commit
0abfd6d
·
1 Parent(s): 1655dfc

fix errors

Browse files
Files changed (1) hide show
  1. app.py +196 -196
app.py CHANGED
@@ -23,54 +23,56 @@ from datasets import load_dataset
23
  from peft import LoraConfig, get_peft_model
24
  import time
25
 
26
- # ---------------------------- Cấu Hình ---------------------------- #
27
 
28
  DESCRIPTION = """\
29
- # Llama 3.2 3B Instruct với Chức Năng Nâng Cao
30
 
31
- Llama 3.2 3B phiên bản mới nhất của Meta về các hình ngôn ngữ mở.
32
- Demo này giới thiệu [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), được tinh chỉnh để theo dõi hướng dẫn.
33
- Để biết thêm chi tiết, vui lòng xem [bài viết của chúng tôi](https://huggingface.co/blog/llama32).
34
  """
35
 
36
- MAX_MAX_NEW_TOKENS = 2048 # Số token tối đa có thể tạo ra
37
- DEFAULT_MAX_NEW_TOKENS = 1024 # Số token tạo ra mặc định
38
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000")) # Độ dài token tối đa cho đầu vào
39
 
40
- # Xác định thiết bị sử dụng (GPU nếu có, ngược lại CPU)
41
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42
 
43
- model_id = "meta-llama/Llama-3.2-3B-Instruct" # ID mô hình, đảm bảo đây là ID mô hình đúng
44
- tokenizer = AutoTokenizer.from_pretrained(model_id) # Tải tokenizer từ Hugging Face
45
  model = AutoModelForCausalLM.from_pretrained(
46
  model_id,
47
  device_map="auto",
48
- torch_dtype=torch.bfloat16, # Sử dụng dtype phù hợp để tiết kiệm bộ nhớ
 
 
 
 
 
 
 
 
 
49
  )
50
- model.to(device) # Di chuyển mô hình tới thiết bị đã chọn
51
- model.eval() # Đặt mô hình ở chế độ đánh giá
52
-
53
- # Khởi tạo pipeline phân tích tâm lý
54
- sentiment_pipeline = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
55
 
56
- # ---------------------------- Định Nghĩa Hàm ---------------------------- #
57
 
58
  @lru_cache(maxsize=128)
59
  def extract_text_from_webpage(html_content: str) -> str:
60
- """Trích xuất văn bản hiển thị từ nội dung HTML sử dụng BeautifulSoup."""
61
  soup = BeautifulSoup(html_content, "html.parser")
62
- # Loại bỏ các thẻ không hiển thị như script, style, header, footer, nav, form, svg
63
  for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
64
  tag.extract()
65
- # Trích xuất văn bản hiển thị, tách bằng dấu cách và loại bỏ khoảng trắng thừa
66
  visible_text = soup.get_text(separator=' ', strip=True)
67
  return visible_text
68
 
69
  def search(query: str) -> List[Dict[str, Any]]:
70
- """Thực hiện tìm kiếm trên Google trả về kết quả."""
71
  term = query
72
  all_results = []
73
- max_chars_per_page = 8000 # Số tự tối đa mỗi trang
74
  headers = {
75
  "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
76
  }
@@ -79,15 +81,15 @@ def search(query: str) -> List[Dict[str, Any]]:
79
  resp = session.get(
80
  url="https://www.google.com/search",
81
  headers=headers,
82
- params={"q": term, "num": 4}, # Tìm kiếm với 4 kết quả mỗi trang
83
  timeout=5,
84
- verify=False, # Bỏ qua xác minh SSL
85
  )
86
- resp.raise_for_status() # Kiểm tra phản hồi HTTP
87
  soup = BeautifulSoup(resp.text, "html.parser")
88
- result_blocks = soup.find_all("div", attrs={"class": "g"}) # Tìm tất cả các khối kết quả
89
  for result in result_blocks:
90
- link_tag = result.find("a", href=True) # Tìm thẻ liên kết
91
  if link_tag and 'href' in link_tag.attrs:
92
  link = link_tag["href"]
93
  try:
@@ -100,22 +102,22 @@ def search(query: str) -> List[Dict[str, Any]]:
100
  webpage.raise_for_status()
101
  visible_text = extract_text_from_webpage(webpage.text)
102
  if len(visible_text) > max_chars_per_page:
103
- visible_text = visible_text[:max_chars_per_page] # Cắt văn bản nếu quá dài
104
  all_results.append({"link": link, "text": visible_text})
105
  except requests.exceptions.RequestException:
106
- all_results.append({"link": link, "text": "Không thể lấy nội dung."})
107
  except requests.exceptions.RequestException as e:
108
- all_results.append({"link": "N/A", "text": "Không thể thực hiện tìm kiếm."})
109
  return all_results
110
 
111
  def summarize_text(text: str, max_length: int = 150) -> str:
112
- """Tóm tắt văn bản sử dụng mô hình Llama."""
113
  conversation = [
114
- {"role": "user", "content": f"Hãy tóm tắt đoạn văn sau: {text}"}
115
  ]
116
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
117
  input_ids = input_ids.to(device)
118
-
119
  summary_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
120
  summary_kwargs = {
121
  "input_ids": input_ids,
@@ -127,40 +129,40 @@ def summarize_text(text: str, max_length: int = 150) -> str:
127
  }
128
  t = Thread(target=model.generate, kwargs=summary_kwargs)
129
  t.start()
130
-
131
  summary = ""
132
  for new_text in summary_streamer:
133
  summary += new_text
134
  return summary
135
 
136
  def analyze_sentiment(text: str) -> str:
137
- """Phân tích tâm của văn bản sử dụng mô hình."""
138
  result = sentiment_pipeline(text)
139
  sentiment = result[0]['label']
140
  score = result[0]['score']
141
- return f"🟢 **Tâm lý**: {sentiment} (Điểm: {score:.2f})"
142
 
143
  def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
144
  """
145
- Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming.
146
  """
147
- # Xây dựng lịch sử cuộc trò chuyện
148
  conversation = []
149
  for user, assistant in chat_history:
150
  conversation.extend([
151
  {"role": "user", "content": user},
152
  {"role": "assistant", "content": assistant},
153
  ])
154
- conversation.append({"role": "user", "content": prompt}) # Thêm tin nhắn của người dùng
155
-
156
- # Chuẩn bị input_ids từ tokenizer
157
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
158
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
159
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] # Cắt input nếu quá dài
160
- gr.Warning(f"Đã cắt bỏ phần cuộc trò chuyện vì vượt quá {MAX_INPUT_TOKEN_LENGTH} token.")
161
- input_ids = input_ids.to(device) # Di chuyển input tới thiết bị
162
-
163
- # Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực
164
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
165
  generate_kwargs = {
166
  "input_ids": input_ids,
@@ -173,10 +175,10 @@ def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_
173
  "num_beams": 1,
174
  "repetition_penalty": repetition_penalty,
175
  }
176
- t = Thread(target=model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản
177
  t.start()
178
-
179
- # Stream văn bản được tạo ra
180
  outputs = []
181
  for text in streamer:
182
  outputs.append(text)
@@ -185,17 +187,17 @@ def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_
185
  @lru_cache(maxsize=128)
186
  def process_query(query: str) -> Dict[str, Any]:
187
  """
188
- Xác định hàm nào sẽ được gọi dựa trên truy vấn của người dùng.
189
  """
190
- # Định nghĩa các từ khóa hoặc mẫu để xác định hàm
191
- web_search_keywords = ["tìm kiếm", "tìm", "tra cứu", "google", "lookup"]
192
- general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
193
- summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
194
- sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]
195
- train_keywords = ["huấn luyện"]
196
-
197
- query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh
198
-
199
  if any(keyword in query_lower for keyword in web_search_keywords):
200
  function_name = "web_search"
201
  arguments = {"query": query}
@@ -214,7 +216,7 @@ def process_query(query: str) -> Dict[str, Any]:
214
  else:
215
  function_name = "hard_query"
216
  arguments = {"prompt": query}
217
-
218
  return {
219
  "name": function_name,
220
  "arguments": arguments
@@ -222,60 +224,60 @@ def process_query(query: str) -> Dict[str, Any]:
222
 
223
  def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
224
  """
225
- Thực thi hàm phù hợp dựa trên lời gọi hàm.
226
  """
227
  function_name = function_call["name"]
228
  arguments = function_call["arguments"]
229
-
230
  if function_name == "web_search":
231
  query = arguments["query"]
232
- yield "🔍 Đang thực hiện tìm kiếm trên web..."
233
  web_results = search(query)
234
  if not web_results:
235
- yield "⚠️ Không tìm thấy kết quả."
236
  return
237
- # Tóm tắt kết quả tìm kiếm
238
- web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 ** tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
239
  if not web_summary:
240
- web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
241
-
242
- # Trả về kết quả tìm kiếm cho người dùng
243
- yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
244
-
245
  elif function_name == "summarize_query":
246
- # Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
247
  query = arguments["prompt"]
248
- yield "🔍 Đang thực hiện tìm kiếm để tóm tắt..."
249
  web_results = search(query)
250
  if not web_results:
251
- yield "⚠️ Không tìm thấy kết quả để tóm tắt."
252
  return
253
- # Lấy nội dung từ kết quả tìm kiếm để tóm tắt
254
- combined_text = ' '.join([res['text'] for res in web_results if res['text'] != "Không thể lấy nội dung."])
255
  if not combined_text:
256
- yield "⚠️ Không nội dung để tóm tắt."
257
  return
258
- # Tóm tắt nội dung đã lấy
259
- yield "📝 Đang tóm tắt thông tin..."
260
  summary = summarize_text(combined_text)
261
- yield "📄 **Tóm tắt:**\n" + summary
262
-
263
  elif function_name == "sentiment_analysis":
264
  prompt_text = arguments["prompt"]
265
- yield "📊 Đang phân tích tâm lý..."
266
  sentiment = analyze_sentiment(prompt_text)
267
  yield sentiment
268
-
269
  elif function_name == "train_model":
270
  prompt_text = arguments["prompt"]
271
- yield "📊 Đang huấn luyện mô hình..."
272
  training_result = run_training()
273
  yield training_result
274
-
275
  elif function_name in ["general_query", "hard_query"]:
276
  prompt_text = arguments["prompt"]
277
- yield "🤖 Đang tạo phản hồi..."
278
- # Tạo phản hồi sử dụng mô hình Llama
279
  response_generator = generate_response(
280
  prompt=prompt_text,
281
  chat_history=chat_history,
@@ -287,26 +289,26 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
287
  )
288
  for response in response_generator:
289
  yield response
290
-
291
  else:
292
- yield "⚠️ Lời gọi hàm không được nhận dạng."
293
 
294
- # ---------------------------- Huấn luyện ---------------------------- #
295
 
296
- # Đường dẫn lưu checkpoint
297
  CHECKPOINT_DIR = "./checkpoints"
298
  if not os.path.exists(CHECKPOINT_DIR):
299
  os.makedirs(CHECKPOINT_DIR)
300
 
301
- # Tải Dataset (CPU)
302
  dataset = load_dataset('vntc/wiki-mini-corpus')
303
 
304
- # Chia Dataset thành train validation (CPU)
305
  split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
306
  train_dataset = split_dataset['train']
307
  validation_dataset = split_dataset['test']
308
 
309
- # Tiền Xử Lý Văn Bản (CPU)
310
  def preprocess_function(examples):
311
  passages = [passage.lower().strip() for passage in examples['passage']]
312
  return {'passage': passages}
@@ -314,7 +316,7 @@ def preprocess_function(examples):
314
  processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
315
  processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
316
 
317
- # Đảm bảo tokenizer pad_token
318
  if tokenizer.pad_token is None:
319
  tokenizer.pad_token = tokenizer.eos_token
320
 
@@ -324,12 +326,13 @@ def tokenize_function(examples):
324
  padding='max_length',
325
  truncation=True,
326
  max_length=512,
 
327
  )
328
 
329
  tokenized_train = processed_train.map(tokenize_function, batched=True)
330
  tokenized_validation = processed_validation.map(tokenize_function, batched=True)
331
 
332
- # Thêm trường 'labels' (CPU)
333
  def add_labels(examples):
334
  examples['labels'] = examples['input_ids'].copy()
335
  return examples
@@ -337,31 +340,31 @@ def add_labels(examples):
337
  tokenized_train = tokenized_train.map(add_labels, batched=True)
338
  tokenized_validation = tokenized_validation.map(add_labels, batched=True)
339
 
340
- # Loại bỏ các cột không cần thiết (CPU)
341
  tokenized_train = tokenized_train.remove_columns(['passage'])
342
  tokenized_validation = tokenized_validation.remove_columns(['passage'])
343
 
344
- # Định dạng dữ liệu cho PyTorch (CPU)
345
  tokenized_train.set_format('torch')
346
  tokenized_validation.set_format('torch')
347
 
348
- # Tạo DatasetDict (CPU)
349
  final_dataset = {
350
  'train': tokenized_train,
351
  'validation': tokenized_validation
352
  }
353
 
354
- # Định Nghĩa TrainerCallback để Lưu Checkpoint Nhanh Hơn
355
  class SaveCheckpointCallback(TrainerCallback):
356
  def on_step_end(self, args, state, control, **kwargs):
357
  if state.global_step % args.save_steps == 0 and state.global_step != 0:
358
  checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
359
- print(f"Lưu checkpoint tại: {checkpoint_path}")
360
- trainer = kwargs['trainer'] # Truy cập trainer từ kwargs
361
  trainer.save_model(checkpoint_path)
362
- return control # Trả về đối tượng control hiện tại
363
 
364
- # Tải hình đã được pretrained
365
  pretrained = AutoModelForCausalLM.from_pretrained(
366
  model_id,
367
  device_map="auto",
@@ -371,31 +374,30 @@ pretrained = AutoModelForCausalLM.from_pretrained(
371
 
372
  data_collator = DataCollatorForLanguageModeling(
373
  tokenizer=tokenizer,
374
- mlm=False, # Vì bạn đang thực hiện Causal LM
375
  pad_to_multiple_of=8
376
  )
377
 
378
  def get_step_done() -> int:
379
  """
380
- Lấy số bước huấn luyện đã hoàn thành từ checkpoint mới nhất trong thư mục lưu trữ.
381
 
382
  Returns:
383
- int: Số bước đã hoàn thành. Trả về 0 nếu không tìm thấy checkpoint.
384
  """
385
  checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint-')]
386
  if not checkpoints:
387
  return 0
388
  try:
389
- # Tìm checkpoint mới nhất dựa trên số bước
390
  latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('-')[1]))
391
  step_done = int(latest_checkpoint.split('-')[1])
392
  return step_done
393
  except (IndexError, ValueError) as e:
394
- print(f"Lỗi khi phân tích tên checkpoint: {e}")
395
  return 0
396
 
397
-
398
- # Tải và Cấu Hình Mô Hình với LoRA (GPU)
399
  lora_config = LoraConfig(
400
  r=8,
401
  lora_alpha=32,
@@ -410,34 +412,34 @@ print(pretrained_model)
410
  @spaces.GPU(duration=30, queue=False)
411
  def run_training() -> str:
412
  """
413
- Hàm huấn luyện hình sử dụng GPU với thời gian hạn chế.
414
 
415
  Returns:
416
- str: Thông báo kết quả huấn luyện.
417
  """
418
 
419
- # Cấu Hình TrainingArguments (GPU)
420
  training_args = TrainingArguments(
421
  output_dir=CHECKPOINT_DIR,
422
  per_device_train_batch_size=4,
423
  per_device_eval_batch_size=4,
424
  gradient_accumulation_steps=8,
425
  num_train_epochs=3,
426
- max_steps=300, # Đặt tổng số bước huấn luyện
427
  learning_rate=3e-4,
428
  weight_decay=0.01,
429
- logging_steps=1, # Ghi log sau mỗi 10 bước
430
- eval_strategy="steps", # Đánh giá sau mỗi vài bước
431
- eval_steps=5, # Đánh giá sau mỗi 50 bước
432
- save_strategy="steps", # Lưu checkpoint sau mỗi vài bước
433
- save_steps=5, # Lưu checkpoint sau mỗi 50 bước
434
- save_total_limit=5, # Giới hạn số lượng checkpoint lưu trữ
435
  fp16=True,
436
  report_to="none",
437
  load_best_model_at_end=True,
438
  )
439
 
440
- # Tạo Trainer (GPU)
441
  trainer = Trainer(
442
  model=pretrained_model,
443
  args=training_args,
@@ -445,50 +447,50 @@ def run_training() -> str:
445
  eval_dataset=final_dataset['validation'],
446
  tokenizer=tokenizer,
447
  data_collator=data_collator,
448
- callbacks=[SaveCheckpointCallback()], # Thêm callback
449
  )
450
 
451
- # Kiểm tra nếu checkpoint
452
  steps_done = get_step_done()
453
  if steps_done > 0:
454
- # Xác định checkpoint mới nhất dựa trên số bước
455
  latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
456
  if os.path.exists(latest_checkpoint):
457
- print(f"Đang tiếp tục huấn luyện từ checkpoint: {latest_checkpoint}")
458
  trainer.train(resume_from_checkpoint=latest_checkpoint)
459
  else:
460
- print(f"Checkpoint {latest_checkpoint} không tồn tại. Bắt đầu huấn luyện từ đầu.")
461
  trainer.train()
462
  else:
463
  trainer.train()
464
 
465
- # Lưu checkpoint sau khi huấn luyện
466
  trainer.save_model(CHECKPOINT_DIR)
467
- return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
468
 
469
- # Hàm Tự Động Hóa Việc Gọi Lặp Lại Hàm Huấn Luyện
470
  @spaces.GPU(duration=30, queue=False)
471
  def continuous_training(total_steps=300, steps_per_call=50):
472
  """
473
- Hàm tự động gọi lại `run_training` để hoàn thành quá trình huấn luyện.
474
 
475
  Args:
476
- total_steps (int): Tổng số bước huấn luyện mong muốn.
477
- steps_per_call (int): Số bước huấn luyện mỗi lần gọi hàm.
478
  """
479
  steps_done = get_step_done()
480
  while steps_done < total_steps:
481
  remaining_steps = total_steps - steps_done
482
  current_steps = min(steps_per_call, remaining_steps)
483
- print(f"Bắt đầu huấn luyện cho {current_steps} bước.")
484
 
485
- # Cập nhật TrainingArguments để huấn luyện cho current_steps bước
486
  training_args = TrainingArguments(
487
  output_dir=CHECKPOINT_DIR,
488
  per_device_train_batch_size=4,
489
  per_device_eval_batch_size=4,
490
  gradient_accumulation_steps=8,
491
- num_train_epochs=1, # Huấn luyện trong một epoch
492
  max_steps=current_steps,
493
  learning_rate=3e-4,
494
  weight_decay=0.01,
@@ -503,7 +505,7 @@ def continuous_training(total_steps=300, steps_per_call=50):
503
  load_best_model_at_end=True,
504
  )
505
 
506
- # Tạo Trainer với TrainingArguments mới
507
  trainer = Trainer(
508
  model=pretrained_model,
509
  args=training_args,
@@ -514,30 +516,30 @@ def continuous_training(total_steps=300, steps_per_call=50):
514
  callbacks=[SaveCheckpointCallback()],
515
  )
516
 
517
- # Tiếp tục huấn luyện từ checkpoint hiện tại
518
  if steps_done > 0:
519
  latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
520
  if os.path.exists(latest_checkpoint):
521
- print(f"Đang tiếp tục huấn luyện từ checkpoint: {latest_checkpoint}")
522
  trainer.train(resume_from_checkpoint=latest_checkpoint)
523
  else:
524
- print(f"Checkpoint {latest_checkpoint} không tồn tại. Bắt đầu huấn luyện từ đầu.")
525
  trainer.train()
526
  else:
527
  trainer.train()
528
 
529
  steps_done = get_step_done()
530
- print(f"Đã huấn luyện {steps_done} / {total_steps} bước.")
531
 
532
- # Kiểm tra nếu đã đạt số bước mong muốn
533
  if steps_done >= total_steps:
534
- print("Đã hoàn thành toàn bộ quá trình huấn luyện.")
535
  break
536
 
537
- # Chờ một khoảng thời gian trước khi gọi lại (tùy thuộc vào yêu cầu của hệ thống)
538
- time.sleep(2) # Thời gian chờ thể điều chỉnh
539
 
540
- # ---------------------------- Giao Diện Gradio ---------------------------- #
541
 
542
  @spaces.GPU(duration=30, queue=False)
543
  def generate(
@@ -550,29 +552,29 @@ def generate(
550
  repetition_penalty: float = 1.2,
551
  ) -> Iterator[str]:
552
  """
553
- Hàm chính để xử đầu vào của người dùng và tạo phản hồi.
554
  """
555
- # Thông báo về việc phân tích đầu vào
556
- yield "🔍 Đang phân tích truy vấn của bạn..."
557
 
558
- # Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
559
  function_call = process_query(message)
560
 
561
- # Thông báo về hàm được chọn
562
  if function_call["name"] == "web_search":
563
- yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
564
  elif function_call["name"] == "summarize_query":
565
- yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
566
  elif function_call["name"] == "sentiment_analysis":
567
- yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
568
  elif function_call["name"] in ["general_query", "hard_query"]:
569
- yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
570
  elif function_call["name"] == "train_model":
571
- yield "🛠️ Đã chọn chức năng: Huấn luyện mô hình."
572
  else:
573
- yield "⚠️ Không thể xác định chức năng phù hợp."
574
 
575
- # Xử lời gọi hàm sinh phản hồi tương ứng
576
  response_iterator = handle_functions(
577
  function_call=function_call,
578
  prompt=message,
@@ -587,40 +589,40 @@ def generate(
587
  for response in response_iterator:
588
  yield response
589
 
590
- # Định nghĩa các dụ để hướng dẫn người dùng
591
  EXAMPLES = [
592
- ["Xin chào! Bạn khỏe không?"],
593
- ["Bạn thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"],
594
- ["Giải thích cốt truyện của Lọ Lem trong một câu."],
595
- ["Một người đàn ông cần bao nhiêu giờ để ăn một chiếc máy bay trực thăng?"],
596
- ["Viết một bài báo 100 từ về 'Lợi ích của nguồn mở trong nghiên cứu AI'"],
597
- ["Tìm cung cấp cho tôi tin tức mới nhất về năng lượng tái tạo."],
598
- ["Tìm thông tin về Rạn san hô Great Barrier Reef."],
599
- ["Tóm tắt nội dung về trí tuệ nhân tạo."],
600
- ["Phân tích tâm của đoạn văn sau: Tôi rất vui khi được gặp bạn hôm nay!"],
601
- ["Huấn luyện mô hình!"],
602
  ]
603
 
604
- # Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
605
  chat_interface = gr.ChatInterface(
606
- fn=generate, # Hàm được gọi khi có tương tác từ người dùng
607
  additional_inputs=[
608
  gr.Slider(
609
- label="Số token mới tối đa",
610
  minimum=1,
611
  maximum=MAX_MAX_NEW_TOKENS,
612
  step=1,
613
  value=DEFAULT_MAX_NEW_TOKENS,
614
  ),
615
  gr.Slider(
616
- label="Nhiệt độ",
617
  minimum=0.1,
618
  maximum=4.0,
619
  step=0.1,
620
  value=0.6,
621
  ),
622
  gr.Slider(
623
- label="Top-p (nucleus sampling)",
624
  minimum=0.05,
625
  maximum=1.0,
626
  step=0.05,
@@ -634,47 +636,45 @@ chat_interface = gr.ChatInterface(
634
  value=50,
635
  ),
636
  gr.Slider(
637
- label="Hình phạt sự lặp lại",
638
  minimum=1.0,
639
  maximum=2.0,
640
  step=0.05,
641
  value=1.2,
642
  ),
643
  ],
644
- stop_btn=None, # Không nút dừng
645
- examples=EXAMPLES, # Các dụ được hiển thị cho người dùng
646
- cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ
647
  title="🤖 OpenGPT-4o Chatbot",
648
- description="Một trợ AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản và phân tích tâm lý.",
649
- theme="default", # thể thay đổi theme để giao diện đẹp hơn
650
  )
651
 
652
- # Tạo giao diện chính của Gradio với CSS tùy chỉnh
653
  with gr.Blocks(css="""
654
  .gradio-container {
655
- background-color: #f0f2f5; /* Màu nền nhẹ nhàng */
656
  }
657
  .gradio-container h1 {
658
- color: #4a90e2; /* Màu xanh dương cho tiêu đề */
659
  }
660
  .gradio-container .gr-button {
661
- background-color: #4a90e2; /* Màu xanh dương cho nút */
662
- color: white; /* Màu chữ trắng trên nút */
663
  }
664
  .gradio-container .gr-slider__label {
665
- color: #333333; /* Màu chữ đen cho nhãn slider */
666
  }
667
  .gradio-container .gr-chatbot {
668
- border: 2px solid #4a90e2; /* Viền xanh dương cho chatbot */
669
- border-radius: 10px; /* Bo góc viền chatbot */
670
- padding: 10px; /* Khoảng cách bên trong chatbot */
671
- background-color: #ffffff; /* Màu nền trắng cho chatbot */
672
  }
673
  """, fill_height=True) as demo:
674
- gr.Markdown(DESCRIPTION) # Hiển thị mô tả
675
- # Nút nhân bản không gian (nếu cần thiết)
676
- # gr.DuplicateButton(value="Nhân bản Không gian để sử dụng riêng tư", elem_id="duplicate-button") # Uncomment if gr.DuplicateButton is needed
677
- chat_interface.render() # Hiển thị giao diện trò chuyện
678
 
679
  if __name__ == "__main__":
680
- demo.queue(max_size=30).launch() # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là 30
 
23
  from peft import LoraConfig, get_peft_model
24
  import time
25
 
26
+ # ---------------------------- Configuration ---------------------------- #
27
 
28
  DESCRIPTION = """\
29
+ # Llama 3.2 3B Instruct with Advanced Features
30
 
31
+ Llama 3.2 3B is the latest version from Meta for open language models.
32
+ This demo showcases [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), fine-tuned for instruction-following.
33
+ For more details, please see [our blog post](https://huggingface.co/blog/llama32).
34
  """
35
 
36
+ MAX_MAX_NEW_TOKENS = 2048 # Maximum tokens to generate
37
+ DEFAULT_MAX_NEW_TOKENS = 1024 # Default tokens to generate
38
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000")) # Max input token length
39
 
40
+ # Determine device (GPU if available, else CPU)
41
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42
 
43
+ model_id = "meta-llama/Llama-3.2-3B-Instruct" # Model ID
44
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
45
  model = AutoModelForCausalLM.from_pretrained(
46
  model_id,
47
  device_map="auto",
48
+ torch_dtype=torch.float16, # Use float16 for compatibility with fp16=True
49
+ )
50
+ model.to(device)
51
+ model.eval()
52
+
53
+ # Initialize sentiment analysis pipeline on GPU if available
54
+ sentiment_pipeline = pipeline(
55
+ "sentiment-analysis",
56
+ model="nlptown/bert-base-multilingual-uncased-sentiment",
57
+ device=0 if torch.cuda.is_available() else -1
58
  )
 
 
 
 
 
59
 
60
+ # ---------------------------- Function Definitions ---------------------------- #
61
 
62
  @lru_cache(maxsize=128)
63
  def extract_text_from_webpage(html_content: str) -> str:
64
+ """Extract visible text from HTML content using BeautifulSoup."""
65
  soup = BeautifulSoup(html_content, "html.parser")
 
66
  for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
67
  tag.extract()
 
68
  visible_text = soup.get_text(separator=' ', strip=True)
69
  return visible_text
70
 
71
  def search(query: str) -> List[Dict[str, Any]]:
72
+ """Perform a Google search and return results."""
73
  term = query
74
  all_results = []
75
+ max_chars_per_page = 8000 # Max characters per page
76
  headers = {
77
  "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
78
  }
 
81
  resp = session.get(
82
  url="https://www.google.com/search",
83
  headers=headers,
84
+ params={"q": term, "num": 4}, # 4 results per page
85
  timeout=5,
86
+ verify=False, # Skip SSL verification
87
  )
88
+ resp.raise_for_status()
89
  soup = BeautifulSoup(resp.text, "html.parser")
90
+ result_blocks = soup.find_all("div", attrs={"class": "g"})
91
  for result in result_blocks:
92
+ link_tag = result.find("a", href=True)
93
  if link_tag and 'href' in link_tag.attrs:
94
  link = link_tag["href"]
95
  try:
 
102
  webpage.raise_for_status()
103
  visible_text = extract_text_from_webpage(webpage.text)
104
  if len(visible_text) > max_chars_per_page:
105
+ visible_text = visible_text[:max_chars_per_page]
106
  all_results.append({"link": link, "text": visible_text})
107
  except requests.exceptions.RequestException:
108
+ all_results.append({"link": link, "text": "Could not retrieve content."})
109
  except requests.exceptions.RequestException as e:
110
+ all_results.append({"link": "N/A", "text": "Could not perform search."})
111
  return all_results
112
 
113
  def summarize_text(text: str, max_length: int = 150) -> str:
114
+ """Summarize text using the Llama model."""
115
  conversation = [
116
+ {"role": "user", "content": f"Please summarize the following text: {text}"}
117
  ]
118
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
119
  input_ids = input_ids.to(device)
120
+
121
  summary_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
122
  summary_kwargs = {
123
  "input_ids": input_ids,
 
129
  }
130
  t = Thread(target=model.generate, kwargs=summary_kwargs)
131
  t.start()
132
+
133
  summary = ""
134
  for new_text in summary_streamer:
135
  summary += new_text
136
  return summary
137
 
138
  def analyze_sentiment(text: str) -> str:
139
+ """Analyze sentiment of the text using a sentiment analysis model."""
140
  result = sentiment_pipeline(text)
141
  sentiment = result[0]['label']
142
  score = result[0]['score']
143
+ return f"🟢 **Sentiment**: {sentiment} (Score: {score:.2f})"
144
 
145
  def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
146
  """
147
+ Generate a response using the Llama model in streaming mode.
148
  """
149
+ # Build conversation history
150
  conversation = []
151
  for user, assistant in chat_history:
152
  conversation.extend([
153
  {"role": "user", "content": user},
154
  {"role": "assistant", "content": assistant},
155
  ])
156
+ conversation.append({"role": "user", "content": prompt}) # Add user's message
157
+
158
+ # Prepare input_ids from tokenizer
159
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
160
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
161
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] # Truncate input if too long
162
+ gr.Warning(f"Truncated conversation due to exceeding {MAX_INPUT_TOKEN_LENGTH} tokens.")
163
+ input_ids = input_ids.to(device)
164
+
165
+ # Initialize streamer for real-time text generation
166
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
167
  generate_kwargs = {
168
  "input_ids": input_ids,
 
175
  "num_beams": 1,
176
  "repetition_penalty": repetition_penalty,
177
  }
178
+ t = Thread(target=model.generate, kwargs=generate_kwargs) # Create thread for text generation
179
  t.start()
180
+
181
+ # Stream generated text
182
  outputs = []
183
  for text in streamer:
184
  outputs.append(text)
 
187
  @lru_cache(maxsize=128)
188
  def process_query(query: str) -> Dict[str, Any]:
189
  """
190
+ Determine which function to call based on the user's query.
191
  """
192
+ # Define keywords or patterns to identify functions
193
+ web_search_keywords = ["search", "find", "lookup", "google"]
194
+ general_query_keywords = ["explain", "describe", "tell me about", "what is", "how to"]
195
+ summarize_keywords = ["summarize", "summarise", "brief", "short"]
196
+ sentiment_keywords = ["emotion", "mood", "sentiment", "analyze sentiment"]
197
+ train_keywords = ["train"]
198
+
199
+ query_lower = query.lower()
200
+
201
  if any(keyword in query_lower for keyword in web_search_keywords):
202
  function_name = "web_search"
203
  arguments = {"query": query}
 
216
  else:
217
  function_name = "hard_query"
218
  arguments = {"prompt": query}
219
+
220
  return {
221
  "name": function_name,
222
  "arguments": arguments
 
224
 
225
  def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
226
  """
227
+ Execute the appropriate function based on the function call.
228
  """
229
  function_name = function_call["name"]
230
  arguments = function_call["arguments"]
231
+
232
  if function_name == "web_search":
233
  query = arguments["query"]
234
+ yield "🔍 Performing web search..."
235
  web_results = search(query)
236
  if not web_results:
237
+ yield "⚠️ No results found."
238
  return
239
+ # Summarize search results
240
+ web_summary = '\n\n'.join([f"🔗 **Link**: {res['link']}\n📝 **Description**: {res['text']}" for res in web_results if res["text"] != "Could not retrieve content."])
241
  if not web_summary:
242
+ web_summary = "⚠️ Could not retrieve content from search results."
243
+
244
+ # Return search results to user
245
+ yield "📄 **Search Results:**\n" + web_summary
246
+
247
  elif function_name == "summarize_query":
248
+ # When user requests summarization, perform search and then summarize
249
  query = arguments["prompt"]
250
+ yield "🔍 Performing search for summarization..."
251
  web_results = search(query)
252
  if not web_results:
253
+ yield "⚠️ No results found to summarize."
254
  return
255
+ # Combine content from search results for summarization
256
+ combined_text = ' '.join([res['text'] for res in web_results if res['text'] != "Could not retrieve content."])
257
  if not combined_text:
258
+ yield "⚠️ No content available to summarize."
259
  return
260
+ # Summarize the combined content
261
+ yield "📝 Summarizing information..."
262
  summary = summarize_text(combined_text)
263
+ yield "📄 **Summary:**\n" + summary
264
+
265
  elif function_name == "sentiment_analysis":
266
  prompt_text = arguments["prompt"]
267
+ yield "📊 Analyzing sentiment..."
268
  sentiment = analyze_sentiment(prompt_text)
269
  yield sentiment
270
+
271
  elif function_name == "train_model":
272
  prompt_text = arguments["prompt"]
273
+ yield "📊 Training the model..."
274
  training_result = run_training()
275
  yield training_result
276
+
277
  elif function_name in ["general_query", "hard_query"]:
278
  prompt_text = arguments["prompt"]
279
+ yield "🤖 Generating response..."
280
+ # Generate response using the Llama model
281
  response_generator = generate_response(
282
  prompt=prompt_text,
283
  chat_history=chat_history,
 
289
  )
290
  for response in response_generator:
291
  yield response
292
+
293
  else:
294
+ yield "⚠️ Unrecognized function call."
295
 
296
+ # ---------------------------- Training Setup ---------------------------- #
297
 
298
+ # Checkpoint directory
299
  CHECKPOINT_DIR = "./checkpoints"
300
  if not os.path.exists(CHECKPOINT_DIR):
301
  os.makedirs(CHECKPOINT_DIR)
302
 
303
+ # Load Dataset (CPU)
304
  dataset = load_dataset('vntc/wiki-mini-corpus')
305
 
306
+ # Split Dataset into train and validation (CPU)
307
  split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
308
  train_dataset = split_dataset['train']
309
  validation_dataset = split_dataset['test']
310
 
311
+ # Text Preprocessing (CPU)
312
  def preprocess_function(examples):
313
  passages = [passage.lower().strip() for passage in examples['passage']]
314
  return {'passage': passages}
 
316
  processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
317
  processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
318
 
319
+ # Ensure tokenizer has pad_token
320
  if tokenizer.pad_token is None:
321
  tokenizer.pad_token = tokenizer.eos_token
322
 
 
326
  padding='max_length',
327
  truncation=True,
328
  max_length=512,
329
+ return_tensors="pt"
330
  )
331
 
332
  tokenized_train = processed_train.map(tokenize_function, batched=True)
333
  tokenized_validation = processed_validation.map(tokenize_function, batched=True)
334
 
335
+ # Add 'labels' field (CPU)
336
  def add_labels(examples):
337
  examples['labels'] = examples['input_ids'].copy()
338
  return examples
 
340
  tokenized_train = tokenized_train.map(add_labels, batched=True)
341
  tokenized_validation = tokenized_validation.map(add_labels, batched=True)
342
 
343
+ # Remove unnecessary columns (CPU)
344
  tokenized_train = tokenized_train.remove_columns(['passage'])
345
  tokenized_validation = tokenized_validation.remove_columns(['passage'])
346
 
347
+ # Set format for PyTorch (CPU)
348
  tokenized_train.set_format('torch')
349
  tokenized_validation.set_format('torch')
350
 
351
+ # Create DatasetDict (CPU)
352
  final_dataset = {
353
  'train': tokenized_train,
354
  'validation': tokenized_validation
355
  }
356
 
357
+ # Define TrainerCallback to Save Checkpoints Faster
358
  class SaveCheckpointCallback(TrainerCallback):
359
  def on_step_end(self, args, state, control, **kwargs):
360
  if state.global_step % args.save_steps == 0 and state.global_step != 0:
361
  checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
362
+ print(f"Saving checkpoint at: {checkpoint_path}")
363
+ trainer = kwargs['trainer'] # Access trainer from kwargs
364
  trainer.save_model(checkpoint_path)
365
+ return control # Return current control object
366
 
367
+ # Load pretrained model with LoRA
368
  pretrained = AutoModelForCausalLM.from_pretrained(
369
  model_id,
370
  device_map="auto",
 
374
 
375
  data_collator = DataCollatorForLanguageModeling(
376
  tokenizer=tokenizer,
377
+ mlm=False, # Causal LM
378
  pad_to_multiple_of=8
379
  )
380
 
381
  def get_step_done() -> int:
382
  """
383
+ Get the number of training steps completed from the latest checkpoint.
384
 
385
  Returns:
386
+ int: Number of steps completed. Returns 0 if no checkpoint is found.
387
  """
388
  checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint-')]
389
  if not checkpoints:
390
  return 0
391
  try:
392
+ # Find the latest checkpoint based on step number
393
  latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('-')[1]))
394
  step_done = int(latest_checkpoint.split('-')[1])
395
  return step_done
396
  except (IndexError, ValueError) as e:
397
+ print(f"Error parsing checkpoint name: {e}")
398
  return 0
399
 
400
+ # Load and Configure LoRA (GPU)
 
401
  lora_config = LoraConfig(
402
  r=8,
403
  lora_alpha=32,
 
412
  @spaces.GPU(duration=30, queue=False)
413
  def run_training() -> str:
414
  """
415
+ Train the model using GPU with time constraints.
416
 
417
  Returns:
418
+ str: Training result message.
419
  """
420
 
421
+ # TrainingArguments Configuration (GPU)
422
  training_args = TrainingArguments(
423
  output_dir=CHECKPOINT_DIR,
424
  per_device_train_batch_size=4,
425
  per_device_eval_batch_size=4,
426
  gradient_accumulation_steps=8,
427
  num_train_epochs=3,
428
+ max_steps=300, # Total training steps
429
  learning_rate=3e-4,
430
  weight_decay=0.01,
431
+ logging_steps=1, # Log every step
432
+ eval_strategy="steps", # Evaluate every few steps
433
+ eval_steps=5, # Evaluate every 5 steps
434
+ save_strategy="steps", # Save checkpoint every few steps
435
+ save_steps=5, # Save every 5 steps
436
+ save_total_limit=5, # Limit number of saved checkpoints
437
  fp16=True,
438
  report_to="none",
439
  load_best_model_at_end=True,
440
  )
441
 
442
+ # Initialize Trainer (GPU)
443
  trainer = Trainer(
444
  model=pretrained_model,
445
  args=training_args,
 
447
  eval_dataset=final_dataset['validation'],
448
  tokenizer=tokenizer,
449
  data_collator=data_collator,
450
+ callbacks=[SaveCheckpointCallback()], # Add callback
451
  )
452
 
453
+ # Check for existing checkpoint
454
  steps_done = get_step_done()
455
  if steps_done > 0:
456
+ # Determine the latest checkpoint based on step number
457
  latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
458
  if os.path.exists(latest_checkpoint):
459
+ print(f"Resuming training from checkpoint: {latest_checkpoint}")
460
  trainer.train(resume_from_checkpoint=latest_checkpoint)
461
  else:
462
+ print(f"Checkpoint {latest_checkpoint} does not exist. Starting training from scratch.")
463
  trainer.train()
464
  else:
465
  trainer.train()
466
 
467
+ # Save checkpoint after training
468
  trainer.save_model(CHECKPOINT_DIR)
469
+ return "Training completed or resumed from checkpoint."
470
 
471
+ # Automatic Function to Call Training Repeatedly
472
  @spaces.GPU(duration=30, queue=False)
473
  def continuous_training(total_steps=300, steps_per_call=50):
474
  """
475
+ Automatically call `run_training` to complete the training process.
476
 
477
  Args:
478
+ total_steps (int): Desired total training steps.
479
+ steps_per_call (int): Training steps per function call.
480
  """
481
  steps_done = get_step_done()
482
  while steps_done < total_steps:
483
  remaining_steps = total_steps - steps_done
484
  current_steps = min(steps_per_call, remaining_steps)
485
+ print(f"Starting training for {current_steps} steps.")
486
 
487
+ # Update TrainingArguments for current_steps
488
  training_args = TrainingArguments(
489
  output_dir=CHECKPOINT_DIR,
490
  per_device_train_batch_size=4,
491
  per_device_eval_batch_size=4,
492
  gradient_accumulation_steps=8,
493
+ num_train_epochs=1, # Train for one epoch
494
  max_steps=current_steps,
495
  learning_rate=3e-4,
496
  weight_decay=0.01,
 
505
  load_best_model_at_end=True,
506
  )
507
 
508
+ # Initialize Trainer with updated TrainingArguments
509
  trainer = Trainer(
510
  model=pretrained_model,
511
  args=training_args,
 
516
  callbacks=[SaveCheckpointCallback()],
517
  )
518
 
519
+ # Resume training from latest checkpoint
520
  if steps_done > 0:
521
  latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
522
  if os.path.exists(latest_checkpoint):
523
+ print(f"Resuming training from checkpoint: {latest_checkpoint}")
524
  trainer.train(resume_from_checkpoint=latest_checkpoint)
525
  else:
526
+ print(f"Checkpoint {latest_checkpoint} does not exist. Starting training from scratch.")
527
  trainer.train()
528
  else:
529
  trainer.train()
530
 
531
  steps_done = get_step_done()
532
+ print(f"Trained {steps_done} / {total_steps} steps.")
533
 
534
+ # Check if desired steps are achieved
535
  if steps_done >= total_steps:
536
+ print("Completed the entire training process.")
537
  break
538
 
539
+ # Wait before the next training call
540
+ time.sleep(2) # Adjust wait time as needed
541
 
542
+ # ---------------------------- Gradio Interface ---------------------------- #
543
 
544
  @spaces.GPU(duration=30, queue=False)
545
  def generate(
 
552
  repetition_penalty: float = 1.2,
553
  ) -> Iterator[str]:
554
  """
555
+ Main function to handle user input and generate responses.
556
  """
557
+ # Notify about query analysis
558
+ yield "🔍 Analyzing your query..."
559
 
560
+ # Determine which function to call based on user's message
561
  function_call = process_query(message)
562
 
563
+ # Notify about the selected function
564
  if function_call["name"] == "web_search":
565
+ yield "🛠️ Selected function: Web Search."
566
  elif function_call["name"] == "summarize_query":
567
+ yield "🛠️ Selected function: Text Summarization."
568
  elif function_call["name"] == "sentiment_analysis":
569
+ yield "🛠️ Selected function: Sentiment Analysis."
570
  elif function_call["name"] in ["general_query", "hard_query"]:
571
+ yield "🛠️ Selected function: Answering Questions."
572
  elif function_call["name"] == "train_model":
573
+ yield "🛠️ Selected function: Model Training."
574
  else:
575
+ yield "⚠️ Could not determine an appropriate function."
576
 
577
+ # Execute the function call and generate responses
578
  response_iterator = handle_functions(
579
  function_call=function_call,
580
  prompt=message,
 
589
  for response in response_iterator:
590
  yield response
591
 
592
+ # Define examples to guide users
593
  EXAMPLES = [
594
+ ["Hello! How are you?"],
595
+ ["Can you briefly explain the Python programming language?"],
596
+ ["Explain the plot of Cinderella in one sentence."],
597
+ ["How many hours does a man need to eat a helicopter?"],
598
+ ["Write a 100-word article on 'Benefits of Open Source in AI Research'"],
599
+ ["Search and provide me with the latest news on renewable energy."],
600
+ ["Find information about the Great Barrier Reef coral reefs."],
601
+ ["Summarize information about artificial intelligence."],
602
+ ["Analyze the sentiment of the following text: I am very happy to meet you today!"],
603
+ ["Train the model!"],
604
  ]
605
 
606
+ # Configure Gradio chat interface with enhanced UI
607
  chat_interface = gr.ChatInterface(
608
+ fn=generate, # Function called on user interaction
609
  additional_inputs=[
610
  gr.Slider(
611
+ label="Max New Tokens",
612
  minimum=1,
613
  maximum=MAX_MAX_NEW_TOKENS,
614
  step=1,
615
  value=DEFAULT_MAX_NEW_TOKENS,
616
  ),
617
  gr.Slider(
618
+ label="Temperature",
619
  minimum=0.1,
620
  maximum=4.0,
621
  step=0.1,
622
  value=0.6,
623
  ),
624
  gr.Slider(
625
+ label="Top-p (Nucleus Sampling)",
626
  minimum=0.05,
627
  maximum=1.0,
628
  step=0.05,
 
636
  value=50,
637
  ),
638
  gr.Slider(
639
+ label="Repetition Penalty",
640
  minimum=1.0,
641
  maximum=2.0,
642
  step=0.05,
643
  value=1.2,
644
  ),
645
  ],
646
+ stop_btn=None, # No stop button
647
+ examples=EXAMPLES, # Display examples to users
648
+ cache_examples=False, # Do not cache examples
649
  title="🤖 OpenGPT-4o Chatbot",
650
+ description="A powerful AI assistant using the local Llama-3.2 model with web search, text summarization, and sentiment analysis functionalities.",
651
+ theme="default", # Customize theme as needed
652
  )
653
 
654
+ # Create the main Gradio interface with custom CSS
655
  with gr.Blocks(css="""
656
  .gradio-container {
657
+ background-color: #f0f2f5; /* Light background color */
658
  }
659
  .gradio-container h1 {
660
+ color: #4a90e2; /* Blue color for title */
661
  }
662
  .gradio-container .gr-button {
663
+ background-color: #4a90e2; /* Blue color for buttons */
664
+ color: white; /* White text on buttons */
665
  }
666
  .gradio-container .gr-slider__label {
667
+ color: #333333; /* Dark text for slider labels */
668
  }
669
  .gradio-container .gr-chatbot {
670
+ border: 2px solid #4a90e2; /* Blue border for chatbot */
671
+ border-radius: 10px; /* Rounded corners for chatbot */
672
+ padding: 10px; /* Inner padding for chatbot */
673
+ background-color: #ffffff; /* White background for chatbot */
674
  }
675
  """, fill_height=True) as demo:
676
+ gr.Markdown(DESCRIPTION) # Display description
677
+ chat_interface.render() # Render chat interface
 
 
678
 
679
  if __name__ == "__main__":
680
+ demo.queue(max_size=30).launch() # Launch Gradio app with a queue size of 30