Spaces:
Sleeping
Sleeping
Commit
·
9002cd7
1
Parent(s):
11c6b76
update functions call
Browse files
app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
import os
|
3 |
from threading import Thread
|
4 |
from typing import Iterator, List, Tuple, Dict, Any
|
@@ -6,7 +5,16 @@ from typing import Iterator, List, Tuple, Dict, Any
|
|
6 |
import gradio as gr
|
7 |
import spaces
|
8 |
import torch
|
9 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
from bs4 import BeautifulSoup
|
11 |
import requests
|
12 |
import json
|
@@ -184,10 +192,14 @@ def process_query(query: str) -> Dict[str, Any]:
|
|
184 |
general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
|
185 |
summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
|
186 |
sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]
|
|
|
187 |
|
188 |
query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh
|
189 |
-
|
190 |
-
if any(keyword in query_lower for keyword in
|
|
|
|
|
|
|
191 |
function_name = "web_search"
|
192 |
arguments = {"query": query}
|
193 |
elif any(keyword in query_lower for keyword in summarize_keywords):
|
@@ -202,7 +214,7 @@ def process_query(query: str) -> Dict[str, Any]:
|
|
202 |
else:
|
203 |
function_name = "hard_query"
|
204 |
arguments = {"prompt": query}
|
205 |
-
|
206 |
return {
|
207 |
"name": function_name,
|
208 |
"arguments": arguments
|
@@ -214,7 +226,7 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
214 |
"""
|
215 |
function_name = function_call["name"]
|
216 |
arguments = function_call["arguments"]
|
217 |
-
|
218 |
if function_name == "web_search":
|
219 |
query = arguments["query"]
|
220 |
yield "🔍 Đang thực hiện tìm kiếm trên web..."
|
@@ -226,10 +238,10 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
226 |
web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
|
227 |
if not web_summary:
|
228 |
web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
|
229 |
-
|
230 |
# Trả về kết quả tìm kiếm cho người dùng
|
231 |
yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
|
232 |
-
|
233 |
elif function_name == "summarize_query":
|
234 |
# Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
|
235 |
query = arguments["prompt"]
|
@@ -247,13 +259,21 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
247 |
yield "📝 Đang tóm tắt thông tin..."
|
248 |
summary = summarize_text(combined_text)
|
249 |
yield "📄 **Tóm tắt:**\n" + summary
|
250 |
-
|
251 |
elif function_name == "sentiment_analysis":
|
252 |
prompt_text = arguments["prompt"]
|
253 |
yield "📊 Đang phân tích tâm lý..."
|
254 |
sentiment = analyze_sentiment(prompt_text)
|
255 |
yield sentiment
|
256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
elif function_name in ["general_query", "hard_query"]:
|
258 |
prompt_text = arguments["prompt"]
|
259 |
yield "🤖 Đang tạo phản hồi..."
|
@@ -269,7 +289,7 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
269 |
)
|
270 |
for response in response_generator:
|
271 |
yield response
|
272 |
-
|
273 |
else:
|
274 |
yield "⚠️ Lời gọi hàm không được nhận dạng."
|
275 |
|
@@ -291,23 +311,23 @@ def generate(
|
|
291 |
# Thông báo về việc phân tích đầu vào
|
292 |
yield "🔍 Đang phân tích truy vấn của bạn..."
|
293 |
|
294 |
-
|
295 |
# Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
|
296 |
function_call = process_query(message)
|
297 |
-
|
298 |
# Thông báo về hàm được chọn
|
299 |
if function_call["name"] == "web_search":
|
300 |
yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
|
301 |
elif function_call["name"] == "summarize_query":
|
302 |
yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
|
303 |
elif function_call["name"] == "sentiment_analysis":
|
304 |
-
continuous_training(total_steps=300, steps_per_call=50)
|
305 |
yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
|
|
|
|
|
306 |
elif function_call["name"] in ["general_query", "hard_query"]:
|
307 |
yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
|
308 |
else:
|
309 |
yield "⚠️ Không thể xác định chức năng phù hợp."
|
310 |
-
|
311 |
# Xử lý lời gọi hàm và sinh phản hồi tương ứng
|
312 |
response_iterator = handle_functions(
|
313 |
function_call=function_call,
|
@@ -319,7 +339,7 @@ def generate(
|
|
319 |
top_k=top_k,
|
320 |
repetition_penalty=repetition_penalty
|
321 |
)
|
322 |
-
|
323 |
for response in response_iterator:
|
324 |
yield response
|
325 |
|
@@ -334,6 +354,7 @@ EXAMPLES = [
|
|
334 |
["Tìm thông tin về Rạn san hô Great Barrier Reef."],
|
335 |
["Tóm tắt nội dung về trí tuệ nhân tạo."],
|
336 |
["Phân tích tâm lý của đoạn văn sau: Tôi rất vui khi được gặp bạn hôm nay!"],
|
|
|
337 |
]
|
338 |
|
339 |
# Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
|
@@ -380,10 +401,11 @@ chat_interface = gr.ChatInterface(
|
|
380 |
examples=EXAMPLES, # Các ví dụ được hiển thị cho người dùng
|
381 |
cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ
|
382 |
title="🤖 OpenGPT-4o Chatbot",
|
383 |
-
description="Một trợ lý AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản
|
384 |
theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
|
385 |
)
|
386 |
|
|
|
387 |
|
388 |
# Đường dẫn lưu checkpoint
|
389 |
CHECKPOINT_DIR = "./checkpoints"
|
@@ -453,15 +475,19 @@ class SaveCheckpointCallback(TrainerCallback):
|
|
453 |
if state.global_step % args.save_steps == 0 and state.global_step != 0:
|
454 |
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
|
455 |
print(f"Lưu checkpoint tại: {checkpoint_path}")
|
456 |
-
trainer = kwargs
|
457 |
-
trainer
|
458 |
-
|
|
|
|
|
|
|
459 |
|
460 |
# Định Nghĩa Hàm Huấn Luyện với Decorator @spaces.GPU
|
461 |
-
@spaces.GPU(duration=
|
462 |
-
def run_training():
|
463 |
"""
|
464 |
Hàm huấn luyện mô hình sử dụng GPU với thời gian hạn chế.
|
|
|
465 |
"""
|
466 |
# Tải và Cấu Hình Mô Hình với LoRA (GPU)
|
467 |
model = AutoModelForCausalLM.from_pretrained(
|
@@ -488,15 +514,15 @@ def run_training():
|
|
488 |
per_device_train_batch_size=4,
|
489 |
per_device_eval_batch_size=4,
|
490 |
gradient_accumulation_steps=8,
|
491 |
-
num_train_epochs=
|
492 |
-
max_steps=
|
493 |
learning_rate=3e-4,
|
494 |
weight_decay=0.01,
|
495 |
-
logging_steps=
|
496 |
eval_strategy="steps", # Đánh giá sau mỗi vài bước
|
497 |
-
eval_steps=
|
498 |
save_strategy="steps", # Lưu checkpoint sau mỗi vài bước
|
499 |
-
save_steps=
|
500 |
save_total_limit=5, # Giới hạn số lượng checkpoint lưu trữ
|
501 |
fp16=True,
|
502 |
report_to="none",
|
@@ -518,7 +544,7 @@ def run_training():
|
|
518 |
eval_dataset=final_dataset['validation'],
|
519 |
tokenizer=tokenizer,
|
520 |
data_collator=data_collator,
|
521 |
-
callbacks=[SaveCheckpointCallback()], # Thêm callback
|
522 |
)
|
523 |
|
524 |
# Kiểm tra nếu có checkpoint
|
@@ -534,30 +560,16 @@ def run_training():
|
|
534 |
trainer.save_model(CHECKPOINT_DIR)
|
535 |
return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
|
536 |
|
537 |
-
# Hàm
|
538 |
-
def
|
539 |
"""
|
540 |
-
Hàm
|
541 |
-
|
542 |
-
Args:
|
543 |
-
total_steps (int): Tổng số bước huấn luyện mong muốn.
|
544 |
-
steps_per_call (int): Số bước huấn luyện mỗi lần gọi hàm.
|
545 |
"""
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
steps_done += steps_per_call
|
552 |
-
print(f"Đã huấn luyện {steps_done} / {total_steps} bước.")
|
553 |
-
|
554 |
-
# Kiểm tra nếu đã đạt số bước mong muốn
|
555 |
-
if steps_done >= total_steps:
|
556 |
-
print("Đã hoàn thành toàn bộ quá trình huấn luyện.")
|
557 |
-
break
|
558 |
-
|
559 |
-
# Chờ một khoảng thời gian trước khi gọi lại (tùy thuộc vào yêu cầu của hệ thống)
|
560 |
-
time.sleep(2) # Thời gian chờ có thể điều chỉnh
|
561 |
|
562 |
# Tạo giao diện chính của Gradio với CSS tùy chỉnh
|
563 |
with gr.Blocks(css="""
|
@@ -585,5 +597,7 @@ with gr.Blocks(css="""
|
|
585 |
gr.DuplicateButton(value="Nhân bản Không gian để sử dụng riêng tư", elem_id="duplicate-button") # Nút nhân bản không gian
|
586 |
chat_interface.render() # Hiển thị giao diện trò chuyện
|
587 |
|
|
|
|
|
588 |
if __name__ == "__main__":
|
589 |
-
demo.queue(max_size=30).launch() # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là
|
|
|
|
|
1 |
import os
|
2 |
from threading import Thread
|
3 |
from typing import Iterator, List, Tuple, Dict, Any
|
|
|
5 |
import gradio as gr
|
6 |
import spaces
|
7 |
import torch
|
8 |
+
from transformers import (
|
9 |
+
TrainingArguments,
|
10 |
+
Trainer,
|
11 |
+
DataCollatorForLanguageModeling,
|
12 |
+
TrainerCallback,
|
13 |
+
AutoModelForCausalLM,
|
14 |
+
AutoTokenizer,
|
15 |
+
TextIteratorStreamer,
|
16 |
+
pipeline
|
17 |
+
)
|
18 |
from bs4 import BeautifulSoup
|
19 |
import requests
|
20 |
import json
|
|
|
192 |
general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
|
193 |
summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
|
194 |
sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]
|
195 |
+
train_keywords = ["huấn luyện", "train", "fine-tune", "tinh chỉnh"]
|
196 |
|
197 |
query_lower = query.lower() # Chuyển truy vấn thành chữ thường để so sánh
|
198 |
+
|
199 |
+
if any(keyword in query_lower for keyword in train_keywords):
|
200 |
+
function_name = "train_model"
|
201 |
+
arguments = {"prompt": query}
|
202 |
+
elif any(keyword in query_lower for keyword in web_search_keywords):
|
203 |
function_name = "web_search"
|
204 |
arguments = {"query": query}
|
205 |
elif any(keyword in query_lower for keyword in summarize_keywords):
|
|
|
214 |
else:
|
215 |
function_name = "hard_query"
|
216 |
arguments = {"prompt": query}
|
217 |
+
|
218 |
return {
|
219 |
"name": function_name,
|
220 |
"arguments": arguments
|
|
|
226 |
"""
|
227 |
function_name = function_call["name"]
|
228 |
arguments = function_call["arguments"]
|
229 |
+
|
230 |
if function_name == "web_search":
|
231 |
query = arguments["query"]
|
232 |
yield "🔍 Đang thực hiện tìm kiếm trên web..."
|
|
|
238 |
web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
|
239 |
if not web_summary:
|
240 |
web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
|
241 |
+
|
242 |
# Trả về kết quả tìm kiếm cho người dùng
|
243 |
yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
|
244 |
+
|
245 |
elif function_name == "summarize_query":
|
246 |
# Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
|
247 |
query = arguments["prompt"]
|
|
|
259 |
yield "📝 Đang tóm tắt thông tin..."
|
260 |
summary = summarize_text(combined_text)
|
261 |
yield "📄 **Tóm tắt:**\n" + summary
|
262 |
+
|
263 |
elif function_name == "sentiment_analysis":
|
264 |
prompt_text = arguments["prompt"]
|
265 |
yield "📊 Đang phân tích tâm lý..."
|
266 |
sentiment = analyze_sentiment(prompt_text)
|
267 |
yield sentiment
|
268 |
+
|
269 |
+
elif function_name == "train_model":
|
270 |
+
prompt_text = arguments["prompt"]
|
271 |
+
yield "🛠️ Đang bắt đầu quá trình huấn luyện..."
|
272 |
+
# Gọi hàm huấn luyện
|
273 |
+
training_iterator = train_model()
|
274 |
+
for response in training_iterator:
|
275 |
+
yield response
|
276 |
+
|
277 |
elif function_name in ["general_query", "hard_query"]:
|
278 |
prompt_text = arguments["prompt"]
|
279 |
yield "🤖 Đang tạo phản hồi..."
|
|
|
289 |
)
|
290 |
for response in response_generator:
|
291 |
yield response
|
292 |
+
|
293 |
else:
|
294 |
yield "⚠️ Lời gọi hàm không được nhận dạng."
|
295 |
|
|
|
311 |
# Thông báo về việc phân tích đầu vào
|
312 |
yield "🔍 Đang phân tích truy vấn của bạn..."
|
313 |
|
|
|
314 |
# Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
|
315 |
function_call = process_query(message)
|
316 |
+
|
317 |
# Thông báo về hàm được chọn
|
318 |
if function_call["name"] == "web_search":
|
319 |
yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
|
320 |
elif function_call["name"] == "summarize_query":
|
321 |
yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
|
322 |
elif function_call["name"] == "sentiment_analysis":
|
|
|
323 |
yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
|
324 |
+
elif function_call["name"] == "train_model":
|
325 |
+
yield "🛠️ Đã chọn chức năng: Huấn luyện mô hình."
|
326 |
elif function_call["name"] in ["general_query", "hard_query"]:
|
327 |
yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
|
328 |
else:
|
329 |
yield "⚠️ Không thể xác định chức năng phù hợp."
|
330 |
+
|
331 |
# Xử lý lời gọi hàm và sinh phản hồi tương ứng
|
332 |
response_iterator = handle_functions(
|
333 |
function_call=function_call,
|
|
|
339 |
top_k=top_k,
|
340 |
repetition_penalty=repetition_penalty
|
341 |
)
|
342 |
+
|
343 |
for response in response_iterator:
|
344 |
yield response
|
345 |
|
|
|
354 |
["Tìm thông tin về Rạn san hô Great Barrier Reef."],
|
355 |
["Tóm tắt nội dung về trí tuệ nhân tạo."],
|
356 |
["Phân tích tâm lý của đoạn văn sau: Tôi rất vui khi được gặp bạn hôm nay!"],
|
357 |
+
["Huấn luyện mô hình với dữ liệu mới để cải thiện khả năng hiểu tiếng Việt."], # Ví dụ mới thêm
|
358 |
]
|
359 |
|
360 |
# Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
|
|
|
401 |
examples=EXAMPLES, # Các ví dụ được hiển thị cho người dùng
|
402 |
cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ
|
403 |
title="🤖 OpenGPT-4o Chatbot",
|
404 |
+
description="Một trợ lý AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản, phân tích tâm lý và huấn luyện mô hình.",
|
405 |
theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
|
406 |
)
|
407 |
|
408 |
+
# ---------------------------- Huấn Luyện Mô Hình ---------------------------- #
|
409 |
|
410 |
# Đường dẫn lưu checkpoint
|
411 |
CHECKPOINT_DIR = "./checkpoints"
|
|
|
475 |
if state.global_step % args.save_steps == 0 and state.global_step != 0:
|
476 |
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
|
477 |
print(f"Lưu checkpoint tại: {checkpoint_path}")
|
478 |
+
trainer = kwargs.get('trainer') # Sử dụng get để tránh KeyError
|
479 |
+
if trainer:
|
480 |
+
trainer.save_model(checkpoint_path)
|
481 |
+
else:
|
482 |
+
print("Không thể truy cập 'trainer' từ kwargs.")
|
483 |
+
return control
|
484 |
|
485 |
# Định Nghĩa Hàm Huấn Luyện với Decorator @spaces.GPU
|
486 |
+
@spaces.GPU(duration=60, queue=False) # Tăng duration lên 60 giây
|
487 |
+
def run_training(steps_per_call=5):
|
488 |
"""
|
489 |
Hàm huấn luyện mô hình sử dụng GPU với thời gian hạn chế.
|
490 |
+
Huấn luyện 5 bước mỗi lần gọi.
|
491 |
"""
|
492 |
# Tải và Cấu Hình Mô Hình với LoRA (GPU)
|
493 |
model = AutoModelForCausalLM.from_pretrained(
|
|
|
514 |
per_device_train_batch_size=4,
|
515 |
per_device_eval_batch_size=4,
|
516 |
gradient_accumulation_steps=8,
|
517 |
+
num_train_epochs=1, # Giới hạn epochs để đảm bảo chỉ huấn luyện 5 bước
|
518 |
+
max_steps=steps_per_call, # Đặt max_steps tại đây
|
519 |
learning_rate=3e-4,
|
520 |
weight_decay=0.01,
|
521 |
+
logging_steps=1, # Giảm số bước logging để theo dõi thường xuyên hơn
|
522 |
eval_strategy="steps", # Đánh giá sau mỗi vài bước
|
523 |
+
eval_steps=steps_per_call, # Đánh giá sau mỗi 5 bước
|
524 |
save_strategy="steps", # Lưu checkpoint sau mỗi vài bước
|
525 |
+
save_steps=steps_per_call, # Lưu checkpoint sau mỗi 5 bước
|
526 |
save_total_limit=5, # Giới hạn số lượng checkpoint lưu trữ
|
527 |
fp16=True,
|
528 |
report_to="none",
|
|
|
544 |
eval_dataset=final_dataset['validation'],
|
545 |
tokenizer=tokenizer,
|
546 |
data_collator=data_collator,
|
547 |
+
callbacks=[SaveCheckpointCallback()], # Thêm callback đã sửa đổi
|
548 |
)
|
549 |
|
550 |
# Kiểm tra nếu có checkpoint
|
|
|
560 |
trainer.save_model(CHECKPOINT_DIR)
|
561 |
return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
|
562 |
|
563 |
+
# Hàm Huấn Luyện
|
564 |
+
def train_model():
|
565 |
"""
|
566 |
+
Hàm huấn luyện mô hình, huấn luyện 5 bước mỗi lần và đảm bảo lưu checkpoint.
|
|
|
|
|
|
|
|
|
567 |
"""
|
568 |
+
# Gọi hàm huấn luyện với steps_per_call=5
|
569 |
+
result = run_training(steps_per_call=5)
|
570 |
+
yield result
|
571 |
+
|
572 |
+
# ---------------------------- Giao Diện Gradio ---------------------------- #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
# Tạo giao diện chính của Gradio với CSS tùy chỉnh
|
575 |
with gr.Blocks(css="""
|
|
|
597 |
gr.DuplicateButton(value="Nhân bản Không gian để sử dụng riêng tư", elem_id="duplicate-button") # Nút nhân bản không gian
|
598 |
chat_interface.render() # Hiển thị giao diện trò chuyện
|
599 |
|
600 |
+
# ---------------------------- Khởi Chạy Ứng Dụng ---------------------------- #
|
601 |
+
|
602 |
if __name__ == "__main__":
|
603 |
+
demo.queue(max_size=30).launch() # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là 30
|