hoduyquocbao commited on
Commit
e7a08ba
·
1 Parent(s): 9002cd7

limit steps

Browse files
Files changed (1) hide show
  1. app.py +133 -183
app.py CHANGED
@@ -46,6 +46,7 @@ model = AutoModelForCausalLM.from_pretrained(
46
  model_id,
47
  device_map="auto",
48
  torch_dtype=torch.bfloat16, # Sử dụng dtype phù hợp để tiết kiệm bộ nhớ
 
49
  )
50
  model.to(device) # Di chuyển mô hình tới thiết bị đã chọn
51
  model.eval() # Đặt mô hình ở chế độ đánh giá
@@ -53,6 +54,70 @@ model.eval() # Đặt mô hình ở chế độ đánh giá
53
  # Khởi tạo pipeline phân tích tâm lý
54
  sentiment_pipeline = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # ---------------------------- Định Nghĩa Hàm ---------------------------- #
57
 
58
  @lru_cache(maxsize=128)
@@ -293,183 +358,8 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
293
  else:
294
  yield "⚠️ Lời gọi hàm không được nhận dạng."
295
 
296
- # ---------------------------- Giao Diện Gradio ---------------------------- #
297
-
298
- @spaces.GPU(duration=30, queue=False)
299
- def generate(
300
- message: str,
301
- chat_history: List[Tuple[str, str]],
302
- max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
303
- temperature: float = 0.6,
304
- top_p: float = 0.9,
305
- top_k: int = 50,
306
- repetition_penalty: float = 1.2,
307
- ) -> Iterator[str]:
308
- """
309
- Hàm chính để xử lý đầu vào của người dùng và tạo phản hồi.
310
- """
311
- # Thông báo về việc phân tích đầu vào
312
- yield "🔍 Đang phân tích truy vấn của bạn..."
313
-
314
- # Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
315
- function_call = process_query(message)
316
-
317
- # Thông báo về hàm được chọn
318
- if function_call["name"] == "web_search":
319
- yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
320
- elif function_call["name"] == "summarize_query":
321
- yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
322
- elif function_call["name"] == "sentiment_analysis":
323
- yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
324
- elif function_call["name"] == "train_model":
325
- yield "🛠️ Đã chọn chức năng: Huấn luyện mô hình."
326
- elif function_call["name"] in ["general_query", "hard_query"]:
327
- yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
328
- else:
329
- yield "⚠️ Không thể xác định chức năng phù hợp."
330
-
331
- # Xử lý lời gọi hàm và sinh phản hồi tương ứng
332
- response_iterator = handle_functions(
333
- function_call=function_call,
334
- prompt=message,
335
- chat_history=chat_history,
336
- max_new_tokens=max_new_tokens,
337
- temperature=temperature,
338
- top_p=top_p,
339
- top_k=top_k,
340
- repetition_penalty=repetition_penalty
341
- )
342
-
343
- for response in response_iterator:
344
- yield response
345
-
346
- # Định nghĩa các ví dụ để hướng dẫn người dùng
347
- EXAMPLES = [
348
- ["Xin chào! Bạn khỏe không?"],
349
- ["Bạn có thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"],
350
- ["Giải thích cốt truyện của Cô bé Lọ Lem trong một câu."],
351
- ["Một người đàn ông cần bao nhiêu giờ để ăn một chiếc máy bay trực thăng?"],
352
- ["Viết một bài báo 100 từ về 'Lợi ích của mã nguồn mở trong nghiên cứu AI'"],
353
- ["Tìm và cung cấp cho tôi tin tức mới nhất về năng lượng tái tạo."],
354
- ["Tìm thông tin về Rạn san hô Great Barrier Reef."],
355
- ["Tóm tắt nội dung về trí tuệ nhân tạo."],
356
- ["Phân tích tâm lý của đoạn văn sau: Tôi rất vui khi được gặp bạn hôm nay!"],
357
- ["Huấn luyện mô hình với dữ liệu mới để cải thiện khả năng hiểu tiếng Việt."], # Ví dụ mới thêm
358
- ]
359
-
360
- # Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
361
- chat_interface = gr.ChatInterface(
362
- fn=generate, # Hàm được gọi khi có tương tác từ người dùng
363
- additional_inputs=[
364
- gr.Slider(
365
- label="Số token mới tối đa",
366
- minimum=1,
367
- maximum=MAX_MAX_NEW_TOKENS,
368
- step=1,
369
- value=DEFAULT_MAX_NEW_TOKENS,
370
- ),
371
- gr.Slider(
372
- label="Nhiệt độ",
373
- minimum=0.1,
374
- maximum=4.0,
375
- step=0.1,
376
- value=0.6,
377
- ),
378
- gr.Slider(
379
- label="Top-p (nucleus sampling)",
380
- minimum=0.05,
381
- maximum=1.0,
382
- step=0.05,
383
- value=0.9,
384
- ),
385
- gr.Slider(
386
- label="Top-k",
387
- minimum=1,
388
- maximum=1000,
389
- step=1,
390
- value=50,
391
- ),
392
- gr.Slider(
393
- label="Hình phạt sự lặp lại",
394
- minimum=1.0,
395
- maximum=2.0,
396
- step=0.05,
397
- value=1.2,
398
- ),
399
- ],
400
- stop_btn=None, # Không có nút dừng
401
- examples=EXAMPLES, # Các ví dụ được hiển thị cho người dùng
402
- cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ
403
- title="🤖 OpenGPT-4o Chatbot",
404
- description="Một trợ lý AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản, phân tích tâm lý và huấn luyện mô hình.",
405
- theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
406
- )
407
-
408
  # ---------------------------- Huấn Luyện Mô Hình ---------------------------- #
409
 
410
- # Đường dẫn lưu checkpoint
411
- CHECKPOINT_DIR = "./checkpoints"
412
- if not os.path.exists(CHECKPOINT_DIR):
413
- os.makedirs(CHECKPOINT_DIR)
414
-
415
- # Tải Dataset (CPU)
416
- dataset = load_dataset('vntc/wiki-mini-corpus')
417
-
418
- # Chia Dataset thành train và validation (CPU)
419
- split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
420
- train_dataset = split_dataset['train']
421
- validation_dataset = split_dataset['test']
422
-
423
- # Tiền Xử Lý Văn Bản (CPU)
424
- def preprocess_function(examples):
425
- passages = [passage.lower().strip() for passage in examples['passage']]
426
- return {'passage': passages}
427
-
428
- processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
429
- processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
430
-
431
- # Tokenization (CPU)
432
- model_name = "meta-llama/Llama-3.2-3B-Instruct"
433
- tokenizer = AutoTokenizer.from_pretrained(model_name)
434
-
435
- # Đảm bảo tokenizer có pad_token
436
- if tokenizer.pad_token is None:
437
- tokenizer.pad_token = tokenizer.eos_token
438
-
439
- def tokenize_function(examples):
440
- return tokenizer(
441
- examples['passage'],
442
- padding='max_length',
443
- truncation=True,
444
- max_length=512,
445
- )
446
-
447
- tokenized_train = processed_train.map(tokenize_function, batched=True)
448
- tokenized_validation = processed_validation.map(tokenize_function, batched=True)
449
-
450
- # Thêm trường 'labels' (CPU)
451
- def add_labels(examples):
452
- examples['labels'] = examples['input_ids'].copy()
453
- return examples
454
-
455
- tokenized_train = tokenized_train.map(add_labels, batched=True)
456
- tokenized_validation = tokenized_validation.map(add_labels, batched=True)
457
-
458
- # Loại bỏ các cột không cần thiết (CPU)
459
- tokenized_train = tokenized_train.remove_columns(['passage'])
460
- tokenized_validation = tokenized_validation.remove_columns(['passage'])
461
-
462
- # Định dạng dữ liệu cho PyTorch (CPU)
463
- tokenized_train.set_format('torch')
464
- tokenized_validation.set_format('torch')
465
-
466
- # Tạo DatasetDict (CPU)
467
- final_dataset = {
468
- 'train': tokenized_train,
469
- 'validation': tokenized_validation
470
- }
471
-
472
- # Định Nghĩa TrainerCallback để Lưu Checkpoint Nhanh Hơn
473
  class SaveCheckpointCallback(TrainerCallback):
474
  def on_step_end(self, args, state, control, **kwargs):
475
  if state.global_step % args.save_steps == 0 and state.global_step != 0:
@@ -482,7 +372,6 @@ class SaveCheckpointCallback(TrainerCallback):
482
  print("Không thể truy cập 'trainer' từ kwargs.")
483
  return control
484
 
485
- # Định Nghĩa Hàm Huấn Luyện với Decorator @spaces.GPU
486
  @spaces.GPU(duration=60, queue=False) # Tăng duration lên 60 giây
487
  def run_training(steps_per_call=5):
488
  """
@@ -490,12 +379,12 @@ def run_training(steps_per_call=5):
490
  Huấn luyện 5 bước mỗi lần gọi.
491
  """
492
  # Tải và Cấu Hình Mô Hình với LoRA (GPU)
493
- model = AutoModelForCausalLM.from_pretrained(
494
- model_name,
495
- device_map="auto",
496
- torch_dtype=torch.float16,
497
- load_in_8bit=False
498
- )
499
 
500
  lora_config = LoraConfig(
501
  r=8,
@@ -560,7 +449,6 @@ def run_training(steps_per_call=5):
560
  trainer.save_model(CHECKPOINT_DIR)
561
  return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
562
 
563
- # Hàm Huấn Luyện
564
  def train_model():
565
  """
566
  Hàm huấn luyện mô hình, huấn luyện 5 bước mỗi lần và đảm bảo lưu checkpoint.
@@ -571,6 +459,68 @@ def train_model():
571
 
572
  # ---------------------------- Giao Diện Gradio ---------------------------- #
573
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
  # Tạo giao diện chính của Gradio với CSS tùy chỉnh
575
  with gr.Blocks(css="""
576
  .gradio-container {
 
46
  model_id,
47
  device_map="auto",
48
  torch_dtype=torch.bfloat16, # Sử dụng dtype phù hợp để tiết kiệm bộ nhớ
49
+ load_in_8bit=False
50
  )
51
  model.to(device) # Di chuyển mô hình tới thiết bị đã chọn
52
  model.eval() # Đặt mô hình ở chế độ đánh giá
 
54
  # Khởi tạo pipeline phân tích tâm lý
55
  sentiment_pipeline = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
56
 
57
+ # ---------------------------- Tải và Tiền Xử Lý Dữ Liệu ---------------------------- #
58
+
59
+ # Đường dẫn lưu checkpoint
60
+ CHECKPOINT_DIR = "./checkpoints"
61
+ if not os.path.exists(CHECKPOINT_DIR):
62
+ os.makedirs(CHECKPOINT_DIR)
63
+
64
+ # Tải Dataset (CPU)
65
+ dataset = load_dataset('vntc/wiki-mini-corpus')
66
+
67
+ # Chia Dataset thành train và validation (CPU)
68
+ split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
69
+ train_dataset = split_dataset['train']
70
+ validation_dataset = split_dataset['test']
71
+
72
+ # Tiền Xử Lý Văn Bản (CPU)
73
+ def preprocess_function(examples):
74
+ passages = [passage.lower().strip() for passage in examples['passage']]
75
+ return {'passage': passages}
76
+
77
+ processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
78
+ processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
79
+
80
+ # Tokenization (CPU)
81
+ model_name = "meta-llama/Llama-3.2-3B-Instruct"
82
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
83
+
84
+ # Đảm bảo tokenizer có pad_token
85
+ if tokenizer.pad_token is None:
86
+ tokenizer.pad_token = tokenizer.eos_token
87
+
88
+ def tokenize_function(examples):
89
+ return tokenizer(
90
+ examples['passage'],
91
+ padding='max_length',
92
+ truncation=True,
93
+ max_length=512,
94
+ )
95
+
96
+ tokenized_train = processed_train.map(tokenize_function, batched=True)
97
+ tokenized_validation = processed_validation.map(tokenize_function, batched=True)
98
+
99
+ # Thêm trường 'labels' (CPU)
100
+ def add_labels(examples):
101
+ examples['labels'] = examples['input_ids'].copy()
102
+ return examples
103
+
104
+ tokenized_train = tokenized_train.map(add_labels, batched=True)
105
+ tokenized_validation = tokenized_validation.map(add_labels, batched=True)
106
+
107
+ # Loại bỏ các cột không cần thiết (CPU)
108
+ tokenized_train = tokenized_train.remove_columns(['passage'])
109
+ tokenized_validation = tokenized_validation.remove_columns(['passage'])
110
+
111
+ # Định dạng dữ liệu cho PyTorch (CPU)
112
+ tokenized_train.set_format('torch')
113
+ tokenized_validation.set_format('torch')
114
+
115
+ # Tạo DatasetDict (CPU)
116
+ final_dataset = {
117
+ 'train': tokenized_train,
118
+ 'validation': tokenized_validation
119
+ }
120
+
121
  # ---------------------------- Định Nghĩa Hàm ---------------------------- #
122
 
123
  @lru_cache(maxsize=128)
 
358
  else:
359
  yield "⚠️ Lời gọi hàm không được nhận dạng."
360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  # ---------------------------- Huấn Luyện Mô Hình ---------------------------- #
362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  class SaveCheckpointCallback(TrainerCallback):
364
  def on_step_end(self, args, state, control, **kwargs):
365
  if state.global_step % args.save_steps == 0 and state.global_step != 0:
 
372
  print("Không thể truy cập 'trainer' từ kwargs.")
373
  return control
374
 
 
375
  @spaces.GPU(duration=60, queue=False) # Tăng duration lên 60 giây
376
  def run_training(steps_per_call=5):
377
  """
 
379
  Huấn luyện 5 bước mỗi lần gọi.
380
  """
381
  # Tải và Cấu Hình Mô Hình với LoRA (GPU)
382
+ # model = AutoModelForCausalLM.from_pretrained(
383
+ # model_name,
384
+ # device_map="auto",
385
+ # torch_dtype=torch.float16,
386
+ # load_in_8bit=False
387
+ # )
388
 
389
  lora_config = LoraConfig(
390
  r=8,
 
449
  trainer.save_model(CHECKPOINT_DIR)
450
  return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
451
 
 
452
  def train_model():
453
  """
454
  Hàm huấn luyện mô hình, huấn luyện 5 bước mỗi lần và đảm bảo lưu checkpoint.
 
459
 
460
  # ---------------------------- Giao Diện Gradio ---------------------------- #
461
 
462
+ # Định nghĩa các ví dụ để hướng dẫn người dùng
463
+ EXAMPLES = [
464
+ ["Xin chào! Bạn khỏe không?"],
465
+ ["Bạn có thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"],
466
+ ["Giải thích cốt truyện của Cô bé Lọ Lem trong một câu."],
467
+ ["Một người đàn ông cần bao nhiêu giờ để ăn một chiếc máy bay trực thăng?"],
468
+ ["Viết một bài báo 100 từ về 'Lợi ích của mã nguồn mở trong nghiên cứu AI'"],
469
+ ["Tìm và cung cấp cho tôi tin tức mới nhất về năng lượng tái tạo."],
470
+ ["Tìm thông tin về Rạn san hô Great Barrier Reef."],
471
+ ["Tóm tắt nội dung về trí tuệ nhân tạo."],
472
+ ["Phân tích tâm lý của đoạn văn sau: Tôi rất vui khi được gặp bạn hôm nay!"],
473
+ ["Huấn luyện mô hình với dữ liệu mới để cải thiện khả năng hiểu tiếng Việt."], # Ví dụ mới thêm
474
+ ]
475
+
476
+ # Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
477
+ chat_interface = gr.ChatInterface(
478
+ fn=generate, # Hàm được gọi khi có tương tác từ người dùng
479
+ additional_inputs=[
480
+ gr.Slider(
481
+ label="Số token mới tối đa",
482
+ minimum=1,
483
+ maximum=MAX_MAX_NEW_TOKENS,
484
+ step=1,
485
+ value=DEFAULT_MAX_NEW_TOKENS,
486
+ ),
487
+ gr.Slider(
488
+ label="Nhiệt độ",
489
+ minimum=0.1,
490
+ maximum=4.0,
491
+ step=0.1,
492
+ value=0.6,
493
+ ),
494
+ gr.Slider(
495
+ label="Top-p (nucleus sampling)",
496
+ minimum=0.05,
497
+ maximum=1.0,
498
+ step=0.05,
499
+ value=0.9,
500
+ ),
501
+ gr.Slider(
502
+ label="Top-k",
503
+ minimum=1,
504
+ maximum=1000,
505
+ step=1,
506
+ value=50,
507
+ ),
508
+ gr.Slider(
509
+ label="Hình phạt sự lặp lại",
510
+ minimum=1.0,
511
+ maximum=2.0,
512
+ step=0.05,
513
+ value=1.2,
514
+ ),
515
+ ],
516
+ stop_btn=None, # Không có nút dừng
517
+ examples=EXAMPLES, # Các ví dụ được hiển thị cho người dùng
518
+ cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ
519
+ title="🤖 OpenGPT-4o Chatbot",
520
+ description="Một trợ lý AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản, phân tích tâm lý và huấn luyện mô hình.",
521
+ theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
522
+ )
523
+
524
  # Tạo giao diện chính của Gradio với CSS tùy chỉnh
525
  with gr.Blocks(css="""
526
  .gradio-container {