hoduyquocbao commited on
Commit
9002cd7
·
1 Parent(s): 11c6b76

update functions call

Browse files
Files changed (1) hide show
  1. app.py +65 -51
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import os
3
  from threading import Thread
4
  from typing import Iterator, List, Tuple, Dict, Any
@@ -6,7 +5,16 @@ from typing import Iterator, List, Tuple, Dict, Any
6
  import gradio as gr
7
  import spaces
8
  import torch
9
- from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling, TrainerCallback,AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
 
 
 
 
 
 
 
 
 
10
  from bs4 import BeautifulSoup
11
  import requests
12
  import json
@@ -184,10 +192,14 @@ def process_query(query: str) -> Dict[str, Any]:
184
  general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
185
  summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
186
  sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]
 
187
 
188
  query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh
189
-
190
- if any(keyword in query_lower for keyword in web_search_keywords):
 
 
 
191
  function_name = "web_search"
192
  arguments = {"query": query}
193
  elif any(keyword in query_lower for keyword in summarize_keywords):
@@ -202,7 +214,7 @@ def process_query(query: str) -> Dict[str, Any]:
202
  else:
203
  function_name = "hard_query"
204
  arguments = {"prompt": query}
205
-
206
  return {
207
  "name": function_name,
208
  "arguments": arguments
@@ -214,7 +226,7 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
214
  """
215
  function_name = function_call["name"]
216
  arguments = function_call["arguments"]
217
-
218
  if function_name == "web_search":
219
  query = arguments["query"]
220
  yield "🔍 Đang thực hiện tìm kiếm trên web..."
@@ -226,10 +238,10 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
226
  web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
227
  if not web_summary:
228
  web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
229
-
230
  # Trả về kết quả tìm kiếm cho người dùng
231
  yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
232
-
233
  elif function_name == "summarize_query":
234
  # Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
235
  query = arguments["prompt"]
@@ -247,13 +259,21 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
247
  yield "📝 Đang tóm tắt thông tin..."
248
  summary = summarize_text(combined_text)
249
  yield "📄 **Tóm tắt:**\n" + summary
250
-
251
  elif function_name == "sentiment_analysis":
252
  prompt_text = arguments["prompt"]
253
  yield "📊 Đang phân tích tâm lý..."
254
  sentiment = analyze_sentiment(prompt_text)
255
  yield sentiment
256
-
 
 
 
 
 
 
 
 
257
  elif function_name in ["general_query", "hard_query"]:
258
  prompt_text = arguments["prompt"]
259
  yield "🤖 Đang tạo phản hồi..."
@@ -269,7 +289,7 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
269
  )
270
  for response in response_generator:
271
  yield response
272
-
273
  else:
274
  yield "⚠️ Lời gọi hàm không được nhận dạng."
275
 
@@ -291,23 +311,23 @@ def generate(
291
  # Thông báo về việc phân tích đầu vào
292
  yield "🔍 Đang phân tích truy vấn của bạn..."
293
 
294
-
295
  # Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
296
  function_call = process_query(message)
297
-
298
  # Thông báo về hàm được chọn
299
  if function_call["name"] == "web_search":
300
  yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
301
  elif function_call["name"] == "summarize_query":
302
  yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
303
  elif function_call["name"] == "sentiment_analysis":
304
- continuous_training(total_steps=300, steps_per_call=50)
305
  yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
 
 
306
  elif function_call["name"] in ["general_query", "hard_query"]:
307
  yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
308
  else:
309
  yield "⚠️ Không thể xác định chức năng phù hợp."
310
-
311
  # Xử lý lời gọi hàm và sinh phản hồi tương ứng
312
  response_iterator = handle_functions(
313
  function_call=function_call,
@@ -319,7 +339,7 @@ def generate(
319
  top_k=top_k,
320
  repetition_penalty=repetition_penalty
321
  )
322
-
323
  for response in response_iterator:
324
  yield response
325
 
@@ -334,6 +354,7 @@ EXAMPLES = [
334
  ["Tìm thông tin về Rạn san hô Great Barrier Reef."],
335
  ["Tóm tắt nội dung về trí tuệ nhân tạo."],
336
  ["Phân tích tâm lý của đoạn văn sau: Tôi rất vui khi được gặp bạn hôm nay!"],
 
337
  ]
338
 
339
  # Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
@@ -380,10 +401,11 @@ chat_interface = gr.ChatInterface(
380
  examples=EXAMPLES, # Các ví dụ được hiển thị cho người dùng
381
  cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ
382
  title="🤖 OpenGPT-4o Chatbot",
383
- description="Một trợ lý AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản phân tích tâm lý.",
384
  theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
385
  )
386
 
 
387
 
388
  # Đường dẫn lưu checkpoint
389
  CHECKPOINT_DIR = "./checkpoints"
@@ -453,15 +475,19 @@ class SaveCheckpointCallback(TrainerCallback):
453
  if state.global_step % args.save_steps == 0 and state.global_step != 0:
454
  checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
455
  print(f"Lưu checkpoint tại: {checkpoint_path}")
456
- trainer = kwargs['trainer'] # Truy cập trainer từ kwargs
457
- trainer.save_model(checkpoint_path)
458
- return control # Trả về đối tượng control hiện tại
 
 
 
459
 
460
  # Định Nghĩa Hàm Huấn Luyện với Decorator @spaces.GPU
461
- @spaces.GPU(duration=30, queue=False)
462
- def run_training():
463
  """
464
  Hàm huấn luyện mô hình sử dụng GPU với thời gian hạn chế.
 
465
  """
466
  # Tải và Cấu Hình Mô Hình với LoRA (GPU)
467
  model = AutoModelForCausalLM.from_pretrained(
@@ -488,15 +514,15 @@ def run_training():
488
  per_device_train_batch_size=4,
489
  per_device_eval_batch_size=4,
490
  gradient_accumulation_steps=8,
491
- num_train_epochs=3,
492
- max_steps=5, # Đặt max_steps tại đây
493
  learning_rate=3e-4,
494
  weight_decay=0.01,
495
- logging_steps=5, # Giảm số bước logging để theo dõi thường xuyên hơn
496
  eval_strategy="steps", # Đánh giá sau mỗi vài bước
497
- eval_steps=5, # Đánh giá sau mỗi 50 bước
498
  save_strategy="steps", # Lưu checkpoint sau mỗi vài bước
499
- save_steps=5, # Lưu checkpoint sau mỗi 50 bước
500
  save_total_limit=5, # Giới hạn số lượng checkpoint lưu trữ
501
  fp16=True,
502
  report_to="none",
@@ -518,7 +544,7 @@ def run_training():
518
  eval_dataset=final_dataset['validation'],
519
  tokenizer=tokenizer,
520
  data_collator=data_collator,
521
- callbacks=[SaveCheckpointCallback()], # Thêm callback
522
  )
523
 
524
  # Kiểm tra nếu có checkpoint
@@ -534,30 +560,16 @@ def run_training():
534
  trainer.save_model(CHECKPOINT_DIR)
535
  return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
536
 
537
- # Hàm Tự Động Hóa Việc Gọi Lặp Lại Hàm Huấn Luyện
538
- def continuous_training(total_steps=300, steps_per_call=5):
539
  """
540
- Hàm tự động gọi lại `run_training` để hoàn thành quá trình huấn luyện.
541
-
542
- Args:
543
- total_steps (int): Tổng số bước huấn luyện mong muốn.
544
- steps_per_call (int): Số bước huấn luyện mỗi lần gọi hàm.
545
  """
546
- steps_done = 0
547
- while steps_done < total_steps:
548
- print(f"Bắt đầu huấn luyện cho {steps_per_call} bước.")
549
- result = run_training()
550
- print(result)
551
- steps_done += steps_per_call
552
- print(f"Đã huấn luyện {steps_done} / {total_steps} bước.")
553
-
554
- # Kiểm tra nếu đã đạt số bước mong muốn
555
- if steps_done >= total_steps:
556
- print("Đã hoàn thành toàn bộ quá trình huấn luyện.")
557
- break
558
-
559
- # Chờ một khoảng thời gian trước khi gọi lại (tùy thuộc vào yêu cầu của hệ thống)
560
- time.sleep(2) # Thời gian chờ có thể điều chỉnh
561
 
562
  # Tạo giao diện chính của Gradio với CSS tùy chỉnh
563
  with gr.Blocks(css="""
@@ -585,5 +597,7 @@ with gr.Blocks(css="""
585
  gr.DuplicateButton(value="Nhân bản Không gian để sử dụng riêng tư", elem_id="duplicate-button") # Nút nhân bản không gian
586
  chat_interface.render() # Hiển thị giao diện trò chuyện
587
 
 
 
588
  if __name__ == "__main__":
589
- demo.queue(max_size=30).launch() # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là 20
 
 
1
  import os
2
  from threading import Thread
3
  from typing import Iterator, List, Tuple, Dict, Any
 
5
  import gradio as gr
6
  import spaces
7
  import torch
8
+ from transformers import (
9
+ TrainingArguments,
10
+ Trainer,
11
+ DataCollatorForLanguageModeling,
12
+ TrainerCallback,
13
+ AutoModelForCausalLM,
14
+ AutoTokenizer,
15
+ TextIteratorStreamer,
16
+ pipeline
17
+ )
18
  from bs4 import BeautifulSoup
19
  import requests
20
  import json
 
192
  general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
193
  summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
194
  sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]
195
+ train_keywords = ["huấn luyện", "train", "fine-tune", "tinh chỉnh"]
196
 
197
  query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh
198
+
199
+ if any(keyword in query_lower for keyword in train_keywords):
200
+ function_name = "train_model"
201
+ arguments = {"prompt": query}
202
+ elif any(keyword in query_lower for keyword in web_search_keywords):
203
  function_name = "web_search"
204
  arguments = {"query": query}
205
  elif any(keyword in query_lower for keyword in summarize_keywords):
 
214
  else:
215
  function_name = "hard_query"
216
  arguments = {"prompt": query}
217
+
218
  return {
219
  "name": function_name,
220
  "arguments": arguments
 
226
  """
227
  function_name = function_call["name"]
228
  arguments = function_call["arguments"]
229
+
230
  if function_name == "web_search":
231
  query = arguments["query"]
232
  yield "🔍 Đang thực hiện tìm kiếm trên web..."
 
238
  web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
239
  if not web_summary:
240
  web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
241
+
242
  # Trả về kết quả tìm kiếm cho người dùng
243
  yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
244
+
245
  elif function_name == "summarize_query":
246
  # Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
247
  query = arguments["prompt"]
 
259
  yield "📝 Đang tóm tắt thông tin..."
260
  summary = summarize_text(combined_text)
261
  yield "📄 **Tóm tắt:**\n" + summary
262
+
263
  elif function_name == "sentiment_analysis":
264
  prompt_text = arguments["prompt"]
265
  yield "📊 Đang phân tích tâm lý..."
266
  sentiment = analyze_sentiment(prompt_text)
267
  yield sentiment
268
+
269
+ elif function_name == "train_model":
270
+ prompt_text = arguments["prompt"]
271
+ yield "🛠️ Đang bắt đầu quá trình huấn luyện..."
272
+ # Gọi hàm huấn luyện
273
+ training_iterator = train_model()
274
+ for response in training_iterator:
275
+ yield response
276
+
277
  elif function_name in ["general_query", "hard_query"]:
278
  prompt_text = arguments["prompt"]
279
  yield "🤖 Đang tạo phản hồi..."
 
289
  )
290
  for response in response_generator:
291
  yield response
292
+
293
  else:
294
  yield "⚠️ Lời gọi hàm không được nhận dạng."
295
 
 
311
  # Thông báo về việc phân tích đầu vào
312
  yield "🔍 Đang phân tích truy vấn của bạn..."
313
 
 
314
  # Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
315
  function_call = process_query(message)
316
+
317
  # Thông báo về hàm được chọn
318
  if function_call["name"] == "web_search":
319
  yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
320
  elif function_call["name"] == "summarize_query":
321
  yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
322
  elif function_call["name"] == "sentiment_analysis":
 
323
  yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
324
+ elif function_call["name"] == "train_model":
325
+ yield "🛠️ Đã chọn chức năng: Huấn luyện mô hình."
326
  elif function_call["name"] in ["general_query", "hard_query"]:
327
  yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
328
  else:
329
  yield "⚠️ Không thể xác định chức năng phù hợp."
330
+
331
  # Xử lý lời gọi hàm và sinh phản hồi tương ứng
332
  response_iterator = handle_functions(
333
  function_call=function_call,
 
339
  top_k=top_k,
340
  repetition_penalty=repetition_penalty
341
  )
342
+
343
  for response in response_iterator:
344
  yield response
345
 
 
354
  ["Tìm thông tin về Rạn san hô Great Barrier Reef."],
355
  ["Tóm tắt nội dung về trí tuệ nhân tạo."],
356
  ["Phân tích tâm lý của đoạn văn sau: Tôi rất vui khi được gặp bạn hôm nay!"],
357
+ ["Huấn luyện mô hình với dữ liệu mới để cải thiện khả năng hiểu tiếng Việt."], # Ví dụ mới thêm
358
  ]
359
 
360
  # Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
 
401
  examples=EXAMPLES, # Các ví dụ được hiển thị cho người dùng
402
  cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ
403
  title="🤖 OpenGPT-4o Chatbot",
404
+ description="Một trợ lý AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản, phân tích tâm lý và huấn luyện mô hình.",
405
  theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
406
  )
407
 
408
+ # ---------------------------- Huấn Luyện Mô Hình ---------------------------- #
409
 
410
  # Đường dẫn lưu checkpoint
411
  CHECKPOINT_DIR = "./checkpoints"
 
475
  if state.global_step % args.save_steps == 0 and state.global_step != 0:
476
  checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
477
  print(f"Lưu checkpoint tại: {checkpoint_path}")
478
+ trainer = kwargs.get('trainer') # Sử dụng get để tránh KeyError
479
+ if trainer:
480
+ trainer.save_model(checkpoint_path)
481
+ else:
482
+ print("Không thể truy cập 'trainer' từ kwargs.")
483
+ return control
484
 
485
  # Định Nghĩa Hàm Huấn Luyện với Decorator @spaces.GPU
486
+ @spaces.GPU(duration=60, queue=False) # Tăng duration lên 60 giây
487
+ def run_training(steps_per_call=5):
488
  """
489
  Hàm huấn luyện mô hình sử dụng GPU với thời gian hạn chế.
490
+ Huấn luyện 5 bước mỗi lần gọi.
491
  """
492
  # Tải và Cấu Hình Mô Hình với LoRA (GPU)
493
  model = AutoModelForCausalLM.from_pretrained(
 
514
  per_device_train_batch_size=4,
515
  per_device_eval_batch_size=4,
516
  gradient_accumulation_steps=8,
517
+ num_train_epochs=1, # Giới hạn epochs để đảm bảo chỉ huấn luyện 5 bước
518
+ max_steps=steps_per_call, # Đặt max_steps tại đây
519
  learning_rate=3e-4,
520
  weight_decay=0.01,
521
+ logging_steps=1, # Giảm số bước logging để theo dõi thường xuyên hơn
522
  eval_strategy="steps", # Đánh giá sau mỗi vài bước
523
+ eval_steps=steps_per_call, # Đánh giá sau mỗi 5 bước
524
  save_strategy="steps", # Lưu checkpoint sau mỗi vài bước
525
+ save_steps=steps_per_call, # Lưu checkpoint sau mỗi 5 bước
526
  save_total_limit=5, # Giới hạn số lượng checkpoint lưu trữ
527
  fp16=True,
528
  report_to="none",
 
544
  eval_dataset=final_dataset['validation'],
545
  tokenizer=tokenizer,
546
  data_collator=data_collator,
547
+ callbacks=[SaveCheckpointCallback()], # Thêm callback đã sửa đổi
548
  )
549
 
550
  # Kiểm tra nếu có checkpoint
 
560
  trainer.save_model(CHECKPOINT_DIR)
561
  return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
562
 
563
+ # Hàm Huấn Luyện
564
+ def train_model():
565
  """
566
+ Hàm huấn luyện hình, huấn luyện 5 bước mỗi lần đảm bảo lưu checkpoint.
 
 
 
 
567
  """
568
+ # Gọi hàm huấn luyện với steps_per_call=5
569
+ result = run_training(steps_per_call=5)
570
+ yield result
571
+
572
+ # ---------------------------- Giao Diện Gradio ---------------------------- #
 
 
 
 
 
 
 
 
 
 
573
 
574
  # Tạo giao diện chính của Gradio với CSS tùy chỉnh
575
  with gr.Blocks(css="""
 
597
  gr.DuplicateButton(value="Nhân bản Không gian để sử dụng riêng tư", elem_id="duplicate-button") # Nút nhân bản không gian
598
  chat_interface.render() # Hiển thị giao diện trò chuyện
599
 
600
+ # ---------------------------- Khởi Chạy Ứng Dụng ---------------------------- #
601
+
602
  if __name__ == "__main__":
603
+ demo.queue(max_size=30).launch() # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là 30