LLAMA3.2-GRop / app.py
hoduyquocbao's picture
fix errors
b63ef0b
raw
history blame
28.2 kB
import os
from threading import Thread
from typing import Iterator, List, Tuple, Dict, Any
import gradio as gr
import torch
from transformers import (
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
TrainerCallback,
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
pipeline
)
from bs4 import BeautifulSoup
import requests
import json
from functools import lru_cache
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
import time
# ---------------------------- Cấu Hình ---------------------------- #
# Vô hiệu hóa cảnh báo tokenizers_parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "false"
DESCRIPTION = """\
# Llama 3.2 3B Instruct với Chức Năng Nâng Cao
Llama 3.2 3B là phiên bản mới nhất của Meta về các mô hình ngôn ngữ mở.
Demo này giới thiệu [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), được tinh chỉnh để theo dõi hướng dẫn.
Để biết thêm chi tiết, vui lòng xem [bài viết của chúng tôi](https://huggingface.co/blog/llama32).
"""
MAX_MAX_NEW_TOKENS = 2048 # Số token tối đa có thể tạo ra
DEFAULT_MAX_NEW_TOKENS = 1024 # Số token tạo ra mặc định
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000")) # Độ dài token tối đa cho đầu vào
# Xác định thiết bị sử dụng (GPU nếu có, ngược lại CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_id = "meta-llama/Llama-3.2-3B-Instruct" # ID mô hình
# Tải tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Tải mô hình cho huấn luyện và áp dụng LoRA
pretrained = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.float16,
load_in_8bit=False
)
# Cấu hình LoRA
lora_config = LoraConfig(
r=8,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
lora_dropout=0.1,
bias="none",
)
# Áp dụng LoRA vào mô hình
pretrained_model = get_peft_model(pretrained, lora_config)
pretrained_model.print_trainable_parameters()
# Đảm bảo mô hình ở chế độ huấn luyện
pretrained_model.train()
# Khởi tạo pipeline phân tích tâm lý trên GPU nếu có
sentiment_pipeline = pipeline(
"sentiment-analysis",
model="nlptown/bert-base-multilingual-uncased-sentiment",
device=0 if torch.cuda.is_available() else -1
)
# ---------------------------- Định Nghĩa Hàm ---------------------------- #
@lru_cache(maxsize=128)
def extract_text_from_webpage(html_content: str) -> str:
"""Trích xuất văn bản hiển thị từ nội dung HTML sử dụng BeautifulSoup."""
soup = BeautifulSoup(html_content, "html.parser")
# Loại bỏ các thẻ không hiển thị như script, style, header, footer, nav, form, svg
for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
tag.extract()
# Trích xuất văn bản hiển thị, tách bằng dấu cách và loại bỏ khoảng trắng thừa
visible_text = soup.get_text(separator=' ', strip=True)
return visible_text
def search(query: str) -> List[Dict[str, Any]]:
"""Thực hiện tìm kiếm trên Google và trả về kết quả."""
term = query
all_results = []
max_chars_per_page = 8000 # Số ký tự tối đa mỗi trang
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
}
with requests.Session() as session:
try:
resp = session.get(
url="https://www.google.com/search",
headers=headers,
params={"q": term, "num": 4}, # Tìm kiếm với 4 kết quả mỗi trang
timeout=5,
verify=False, # Bỏ qua xác minh SSL
)
resp.raise_for_status()
soup = BeautifulSoup(resp.text, "html.parser")
result_blocks = soup.find_all("div", attrs={"class": "g"}) # Tìm tất cả các khối kết quả
for result in result_blocks:
link_tag = result.find("a", href=True) # Tìm thẻ liên kết
if link_tag and 'href' in link_tag.attrs:
link = link_tag["href"]
try:
webpage = session.get(
link,
headers=headers,
timeout=5,
verify=False
)
webpage.raise_for_status()
visible_text = extract_text_from_webpage(webpage.text)
if len(visible_text) > max_chars_per_page:
visible_text = visible_text[:max_chars_per_page] # Cắt văn bản nếu quá dài
all_results.append({"link": link, "text": visible_text})
except requests.exceptions.RequestException:
all_results.append({"link": link, "text": "Không thể lấy nội dung."})
except requests.exceptions.RequestException as e:
all_results.append({"link": "N/A", "text": "Không thể thực hiện tìm kiếm."})
return all_results
def summarize_text(text: str, max_length: int = 150) -> str:
"""Tóm tắt văn bản sử dụng mô hình Llama."""
conversation = [
{"role": "user", "content": f"Hãy tóm tắt đoạn văn sau: {text}"}
]
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
input_ids = input_ids.to(device)
summary_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
summary_kwargs = {
"input_ids": input_ids,
"streamer": summary_streamer,
"max_new_tokens": max_length,
"do_sample": True,
"top_p": 0.95,
"temperature": 0.7,
}
t = Thread(target=pretrained_model.generate, kwargs=summary_kwargs)
t.start()
summary = ""
for new_text in summary_streamer:
summary += new_text
return summary
def analyze_sentiment(text: str) -> str:
"""Phân tích tâm lý của văn bản sử dụng mô hình."""
result = sentiment_pipeline(text)
sentiment = result[0]['label']
score = result[0]['score']
return f"🟢 **Tâm lý**: {sentiment} (Điểm: {score:.2f})"
def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
"""
Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming.
"""
# Xây dựng lịch sử cuộc trò chuyện
conversation = []
for user, assistant in chat_history:
conversation.extend([
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
])
conversation.append({"role": "user", "content": prompt}) # Thêm tin nhắn của người dùng
# Chuẩn bị input_ids từ tokenizer
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] # Cắt input nếu quá dài
gr.Warning(f"Đã cắt bỏ phần cuộc trò chuyện vì vượt quá {MAX_INPUT_TOKEN_LENGTH} token.")
input_ids = input_ids.to(device) # Di chuyển input tới thiết bị
# Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = {
"input_ids": input_ids,
"streamer": streamer,
"max_new_tokens": max_new_tokens,
"do_sample": True,
"top_p": top_p,
"top_k": top_k,
"temperature": temperature,
"num_beams": 1,
"repetition_penalty": repetition_penalty,
}
t = Thread(target=pretrained_model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản
t.start()
# Stream văn bản được tạo ra
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
@lru_cache(maxsize=128)
def process_query(query: str) -> Dict[str, Any]:
"""
Xác định hàm nào sẽ được gọi dựa trên truy vấn của người dùng.
"""
# Định nghĩa các từ khóa hoặc mẫu để xác định hàm
web_search_keywords = ["tìm kiếm", "tìm", "tra cứu", "google", "lookup"]
general_query_keywords = ["giải thích", "mô tả", "nói cho tôi biết về", "cái gì là", "cách nào"]
summarize_keywords = ["tóm tắt", "tóm lại", "khái quát", "ngắn gọn"]
sentiment_keywords = ["cảm xúc", "tâm trạng", "tâm lý", "phân tích cảm xúc"]
train_keywords = ["huấn luyện"]
query_lower = query.lower()
if any(keyword in query_lower for keyword in web_search_keywords):
function_name = "web_search"
arguments = {"query": query}
elif any(keyword in query_lower for keyword in summarize_keywords):
function_name = "summarize_query"
arguments = {"prompt": query}
elif any(keyword in query_lower for keyword in sentiment_keywords):
function_name = "sentiment_analysis"
arguments = {"prompt": query}
elif any(keyword in query_lower for keyword in general_query_keywords):
function_name = "general_query"
arguments = {"prompt": query}
elif any(keyword in query_lower for keyword in train_keywords):
function_name = "train_model"
arguments = {"prompt": query}
else:
function_name = "hard_query"
arguments = {"prompt": query}
return {
"name": function_name,
"arguments": arguments
}
def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
"""
Thực thi hàm phù hợp dựa trên lời gọi hàm.
"""
function_name = function_call["name"]
arguments = function_call["arguments"]
if function_name == "web_search":
query = arguments["query"]
yield "🔍 Đang thực hiện tìm kiếm trên web..."
web_results = search(query)
if not web_results:
yield "⚠️ Không tìm thấy kết quả."
return
# Tóm tắt kết quả tìm kiếm
web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
if not web_summary:
web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
# Trả về kết quả tìm kiếm cho người dùng
yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
elif function_name == "summarize_query":
# Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
query = arguments["prompt"]
yield "🔍 Đang thực hiện tìm kiếm để tóm tắt..."
web_results = search(query)
if not web_results:
yield "⚠️ Không tìm thấy kết quả để tóm tắt."
return
# Lấy nội dung từ kết quả tìm kiếm để tóm tắt
combined_text = ' '.join([res['text'] for res in web_results if res['text'] != "Không thể lấy nội dung."])
if not combined_text:
yield "⚠️ Không có nội dung để tóm tắt."
return
# Tóm tắt nội dung đã lấy
yield "📝 Đang tóm tắt thông tin..."
summary = summarize_text(combined_text)
yield "📄 **Tóm tắt:**\n" + summary
elif function_name == "sentiment_analysis":
prompt_text = arguments["prompt"]
yield "📊 Đang phân tích tâm lý..."
sentiment = analyze_sentiment(prompt_text)
yield sentiment
elif function_name == "train_model":
prompt_text = arguments["prompt"]
yield "📊 Đang huấn luyện mô hình..."
training_result = run_training()
yield training_result
elif function_name in ["general_query", "hard_query"]:
prompt_text = arguments["prompt"]
yield "🤖 Đang tạo phản hồi..."
# Tạo phản hồi sử dụng mô hình Llama
response_generator = generate_response(
prompt=prompt_text,
chat_history=chat_history,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty
)
for response in response_generator:
yield response
else:
yield "⚠️ Lời gọi hàm không được nhận dạng."
# ---------------------------- Huấn luyện ---------------------------- #
# Đường dẫn lưu checkpoint
CHECKPOINT_DIR = "./checkpoints"
if not os.path.exists(CHECKPOINT_DIR):
os.makedirs(CHECKPOINT_DIR)
# Tải Dataset (CPU)
dataset = load_dataset('vntc/wiki-mini-corpus')
# Chia Dataset thành train và validation (CPU)
split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset['train']
validation_dataset = split_dataset['test']
# Tiền Xử Lý Văn Bản (CPU)
def preprocess_function(examples):
passages = [passage.lower().strip() for passage in examples['passage']]
return {'passage': passages}
processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
# Đảm bảo tokenizer có pad_token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples):
return tokenizer(
examples['passage'],
padding='max_length',
truncation=True,
max_length=512,
return_tensors="pt"
)
tokenized_train = processed_train.map(tokenize_function, batched=True)
tokenized_validation = processed_validation.map(tokenize_function, batched=True)
# Thêm trường 'labels' (CPU)
def add_labels(examples):
examples['labels'] = examples['input_ids'].copy()
return examples
tokenized_train = tokenized_train.map(add_labels, batched=True)
tokenized_validation = tokenized_validation.map(add_labels, batched=True)
# Loại bỏ các cột không cần thiết (CPU)
tokenized_train = tokenized_train.remove_columns(['passage'])
tokenized_validation = tokenized_validation.remove_columns(['passage'])
# Định dạng dữ liệu cho PyTorch (CPU)
tokenized_train.set_format('torch')
tokenized_validation.set_format('torch')
# Tạo DatasetDict (CPU)
final_dataset = {
'train': tokenized_train,
'validation': tokenized_validation
}
# Định Nghĩa TrainerCallback để Lưu Checkpoint Nhanh Hơn
class SaveCheckpointCallback(TrainerCallback):
def on_save(self, args, state, control, **kwargs):
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
print(f"Lưu checkpoint tại: {checkpoint_path}")
kwargs['trainer'].save_model(checkpoint_path)
return control # Trả về đối tượng control hiện tại
# Định Nghĩa TrainerCallback để Xử Lý Kết Thúc Huấn Luyện
class PrintCallback(TrainerCallback):
def on_train_begin(self, args, state, control, **kwargs):
print("Bắt đầu quá trình huấn luyện...")
def on_train_end(self, args, state, control, **kwargs):
print("Quá trình huấn luyện đã kết thúc.")
# Data Collator
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False, # Vì bạn đang thực hiện Causal LM
pad_to_multiple_of=8
)
def get_step_done() -> int:
"""
Lấy số bước huấn luyện đã hoàn thành từ checkpoint mới nhất trong thư mục lưu trữ.
Returns:
int: Số bước đã hoàn thành. Trả về 0 nếu không tìm thấy checkpoint.
"""
checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint-')]
if not checkpoints:
return 0
try:
# Tìm checkpoint mới nhất dựa trên số bước
latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('-')[1]))
step_done = int(latest_checkpoint.split('-')[1])
return step_done
except (IndexError, ValueError) as e:
print(f"Lỗi khi phân tích tên checkpoint: {e}")
return 0
# ---------------------------- Định Nghĩa Huấn Luyện ---------------------------- #
@gradio.GPU # Sử dụng decorator phù hợp nếu cần
def run_training() -> str:
"""
Hàm huấn luyện mô hình sử dụng GPU với thời gian hạn chế.
Returns:
str: Thông báo kết quả huấn luyện.
"""
# Cấu Hình TrainingArguments (GPU)
training_args = TrainingArguments(
output_dir=CHECKPOINT_DIR,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
gradient_accumulation_steps=8,
num_train_epochs=3,
max_steps=300, # Đặt tổng số bước huấn luyện
learning_rate=3e-4,
weight_decay=0.01,
logging_steps=1, # Ghi log sau mỗi bước
eval_strategy="steps", # Đánh giá sau mỗi vài bước
eval_steps=5, # Đánh giá sau mỗi 5 bước
save_strategy="steps", # Lưu checkpoint sau mỗi vài bước
save_steps=5, # Lưu checkpoint sau mỗi 5 bước
save_total_limit=5, # Giới hạn số lượng checkpoint lưu trữ
fp16=True, # Kích hoạt huấn luyện hỗn hợp độ chính xác
report_to="none",
load_best_model_at_end=False, # Tắt load best model để tránh xung đột
)
# Tạo Trainer (GPU)
trainer = Trainer(
model=pretrained_model,
args=training_args,
train_dataset=final_dataset['train'],
eval_dataset=final_dataset['validation'],
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=[SaveCheckpointCallback(), PrintCallback()], # Thêm callback
)
# Kiểm tra nếu có checkpoint
steps_done = get_step_done()
if steps_done > 0:
# Xác định checkpoint mới nhất dựa trên số bước
latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
if os.path.exists(latest_checkpoint):
print(f"Đang tiếp tục huấn luyện từ checkpoint: {latest_checkpoint}")
trainer.train(resume_from_checkpoint=latest_checkpoint)
else:
print(f"Checkpoint {latest_checkpoint} không tồn tại. Bắt đầu huấn luyện từ đầu.")
trainer.train()
else:
trainer.train()
# Lưu checkpoint sau khi huấn luyện
trainer.save_model(CHECKPOINT_DIR)
return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
# Hàm Tự Động Hóa Việc Gọi Lặp Lại Hàm Huấn Luyện
@gradio.GPU
def continuous_training(total_steps=300, steps_per_call=50):
"""
Hàm tự động gọi lại `run_training` để hoàn thành quá trình huấn luyện.
Args:
total_steps (int): Tổng số bước huấn luyện mong muốn.
steps_per_call (int): Số bước huấn luyện mỗi lần gọi hàm.
"""
steps_done = get_step_done()
while steps_done < total_steps:
remaining_steps = total_steps - steps_done
current_steps = min(steps_per_call, remaining_steps)
print(f"Bắt đầu huấn luyện cho {current_steps} bước.")
# Cập nhật TrainingArguments để huấn luyện cho current_steps bước
training_args = TrainingArguments(
output_dir=CHECKPOINT_DIR,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
gradient_accumulation_steps=8,
num_train_epochs=1, # Huấn luyện trong một epoch
max_steps=current_steps,
learning_rate=3e-4,
weight_decay=0.01,
logging_steps=10,
eval_strategy="steps",
eval_steps=50,
save_strategy="steps",
save_steps=50,
save_total_limit=5,
fp16=True,
report_to="none",
load_best_model_at_end=False,
)
# Tạo Trainer với TrainingArguments mới
trainer = Trainer(
model=pretrained_model,
args=training_args,
train_dataset=final_dataset['train'],
eval_dataset=final_dataset['validation'],
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=[SaveCheckpointCallback(), PrintCallback()],
)
# Tiếp tục huấn luyện từ checkpoint hiện tại
if steps_done > 0:
latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
if os.path.exists(latest_checkpoint):
print(f"Đang tiếp tục huấn luyện từ checkpoint: {latest_checkpoint}")
trainer.train(resume_from_checkpoint=latest_checkpoint)
else:
print(f"Checkpoint {latest_checkpoint} không tồn tại. Bắt đầu huấn luyện từ đầu.")
trainer.train()
else:
trainer.train()
steps_done = get_step_done()
print(f"Đã huấn luyện {steps_done} / {total_steps} bước.")
# Kiểm tra nếu đã đạt số bước mong muốn
if steps_done >= total_steps:
print("Đã hoàn thành toàn bộ quá trình huấn luyện.")
break
# Chờ một khoảng thời gian trước khi gọi lại (tùy thuộc vào yêu cầu của hệ thống)
time.sleep(2) # Thời gian chờ có thể điều chỉnh
# ---------------------------- Giao Diện Gradio ---------------------------- #
@gradio.GPU
def generate(
message: str,
chat_history: List[Tuple[str, str]],
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
"""
Hàm chính để xử lý đầu vào của người dùng và tạo phản hồi.
"""
# Thông báo về việc phân tích đầu vào
yield "🔍 Đang phân tích truy vấn của bạn..."
# Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
function_call = process_query(message)
# Thông báo về hàm được chọn
if function_call["name"] == "web_search":
yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
elif function_call["name"] == "summarize_query":
yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
elif function_call["name"] == "sentiment_analysis":
yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
elif function_call["name"] in ["general_query", "hard_query"]:
yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
elif function_call["name"] == "train_model":
yield "🛠️ Đã chọn chức năng: Huấn luyện mô hình."
else:
yield "⚠️ Không thể xác định chức năng phù hợp."
# Xử lý lời gọi hàm và sinh phản hồi tương ứng
response_iterator = handle_functions(
function_call=function_call,
prompt=message,
chat_history=chat_history,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty
)
for response in response_iterator:
yield response
# Định nghĩa các ví dụ để hướng dẫn người dùng
EXAMPLES = [
["Xin chào! Bạn khỏe không?"],
["Bạn có thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"],
["Giải thích cốt truyện của Cô bé Lọ Lem trong một câu."],
["Một người đàn ông cần bao nhiêu giờ để ăn một chiếc máy bay trực thăng?"],
["Viết một bài báo 100 từ về 'Lợi ích của mã nguồn mở trong nghiên cứu AI'"],
["Tìm và cung cấp cho tôi tin tức mới nhất về năng lượng tái tạo."],
["Tìm thông tin về Rạn san hô Great Barrier Reef."],
["Tóm tắt nội dung về trí tuệ nhân tạo."],
["Phân tích tâm lý của đoạn văn sau: Tôi rất vui khi được gặp bạn hôm nay!"],
["Huấn luyện mô hình!"],
]
# Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
chat_interface = gr.ChatInterface(
fn=generate, # Hàm được gọi khi có tương tác từ người dùng
additional_inputs=[
gr.Slider(
label="Số token mới tối đa",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Nhiệt độ",
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.6,
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.9,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
),
gr.Slider(
label="Hình phạt sự lặp lại",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.2,
),
],
stop_btn=None, # Không có nút dừng
examples=EXAMPLES, # Các ví dụ được hiển thị cho người dùng
cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ
title="🤖 OpenGPT-4o Chatbot",
description="Một trợ lý AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản và phân tích tâm lý.",
theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
)
# Tạo giao diện chính của Gradio với CSS tùy chỉnh
with gr.Blocks(css="""
.gradio-container {
background-color: #f0f2f5; /* Màu nền nhẹ nhàng */
}
.gradio-container h1 {
color: #4a90e2; /* Màu xanh dương cho tiêu đề */
}
.gradio-container .gr-button {
background-color: #4a90e2; /* Màu xanh dương cho nút */
color: white; /* Màu chữ trắng trên nút */
}
.gradio-container .gr-slider__label {
color: #333333; /* Màu chữ đen cho nhãn slider */
}
.gradio-container .gr-chatbot {
border: 2px solid #4a90e2; /* Viền xanh dương cho chatbot */
border-radius: 10px; /* Bo góc viền chatbot */
padding: 10px; /* Khoảng cách bên trong chatbot */
background-color: #ffffff; /* Màu nền trắng cho chatbot */
}
""", fill_height=True) as demo:
gr.Markdown(DESCRIPTION) # Hiển thị mô tả
chat_interface.render() # Hiển thị giao diện trò chuyện
if __name__ == "__main__":
demo.queue(max_size=30).launch() # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là 30