Spaces:
Sleeping
Sleeping
Commit
·
c6bca05
1
Parent(s):
0abfd6d
fix errors
Browse files
app.py
CHANGED
@@ -23,56 +23,58 @@ from datasets import load_dataset
|
|
23 |
from peft import LoraConfig, get_peft_model
|
24 |
import time
|
25 |
|
26 |
-
# ----------------------------
|
27 |
|
28 |
DESCRIPTION = """\
|
29 |
-
# Llama 3.2 3B Instruct
|
30 |
|
31 |
-
Llama 3.2 3B
|
32 |
-
|
33 |
-
|
34 |
"""
|
35 |
|
36 |
-
MAX_MAX_NEW_TOKENS = 2048 #
|
37 |
-
DEFAULT_MAX_NEW_TOKENS = 1024 #
|
38 |
-
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000")) #
|
39 |
|
40 |
-
#
|
41 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
42 |
|
43 |
-
model_id = "meta-llama/Llama-3.2-3B-Instruct" #
|
44 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
45 |
model = AutoModelForCausalLM.from_pretrained(
|
46 |
model_id,
|
47 |
device_map="auto",
|
48 |
-
torch_dtype=torch.float16, #
|
49 |
)
|
50 |
model.to(device)
|
51 |
model.eval()
|
52 |
|
53 |
-
#
|
54 |
sentiment_pipeline = pipeline(
|
55 |
"sentiment-analysis",
|
56 |
model="nlptown/bert-base-multilingual-uncased-sentiment",
|
57 |
device=0 if torch.cuda.is_available() else -1
|
58 |
)
|
59 |
|
60 |
-
# ----------------------------
|
61 |
|
62 |
@lru_cache(maxsize=128)
|
63 |
def extract_text_from_webpage(html_content: str) -> str:
|
64 |
-
"""
|
65 |
soup = BeautifulSoup(html_content, "html.parser")
|
|
|
66 |
for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
|
67 |
tag.extract()
|
|
|
68 |
visible_text = soup.get_text(separator=' ', strip=True)
|
69 |
return visible_text
|
70 |
|
71 |
def search(query: str) -> List[Dict[str, Any]]:
|
72 |
-
"""
|
73 |
term = query
|
74 |
all_results = []
|
75 |
-
max_chars_per_page = 8000 #
|
76 |
headers = {
|
77 |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
|
78 |
}
|
@@ -81,15 +83,15 @@ def search(query: str) -> List[Dict[str, Any]]:
|
|
81 |
resp = session.get(
|
82 |
url="https://www.google.com/search",
|
83 |
headers=headers,
|
84 |
-
params={"q": term, "num": 4}, # 4
|
85 |
timeout=5,
|
86 |
-
verify=False, #
|
87 |
)
|
88 |
resp.raise_for_status()
|
89 |
soup = BeautifulSoup(resp.text, "html.parser")
|
90 |
-
result_blocks = soup.find_all("div", attrs={"class": "g"})
|
91 |
for result in result_blocks:
|
92 |
-
link_tag = result.find("a", href=True)
|
93 |
if link_tag and 'href' in link_tag.attrs:
|
94 |
link = link_tag["href"]
|
95 |
try:
|
@@ -102,18 +104,18 @@ def search(query: str) -> List[Dict[str, Any]]:
|
|
102 |
webpage.raise_for_status()
|
103 |
visible_text = extract_text_from_webpage(webpage.text)
|
104 |
if len(visible_text) > max_chars_per_page:
|
105 |
-
visible_text = visible_text[:max_chars_per_page]
|
106 |
all_results.append({"link": link, "text": visible_text})
|
107 |
except requests.exceptions.RequestException:
|
108 |
-
all_results.append({"link": link, "text": "
|
109 |
except requests.exceptions.RequestException as e:
|
110 |
-
all_results.append({"link": "N/A", "text": "
|
111 |
return all_results
|
112 |
|
113 |
def summarize_text(text: str, max_length: int = 150) -> str:
|
114 |
-
"""
|
115 |
conversation = [
|
116 |
-
{"role": "user", "content": f"
|
117 |
]
|
118 |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
|
119 |
input_ids = input_ids.to(device)
|
@@ -136,33 +138,33 @@ def summarize_text(text: str, max_length: int = 150) -> str:
|
|
136 |
return summary
|
137 |
|
138 |
def analyze_sentiment(text: str) -> str:
|
139 |
-
"""
|
140 |
result = sentiment_pipeline(text)
|
141 |
sentiment = result[0]['label']
|
142 |
score = result[0]['score']
|
143 |
-
return f"🟢 **
|
144 |
|
145 |
def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
|
146 |
"""
|
147 |
-
|
148 |
"""
|
149 |
-
#
|
150 |
conversation = []
|
151 |
for user, assistant in chat_history:
|
152 |
conversation.extend([
|
153 |
{"role": "user", "content": user},
|
154 |
{"role": "assistant", "content": assistant},
|
155 |
])
|
156 |
-
conversation.append({"role": "user", "content": prompt}) #
|
157 |
|
158 |
-
#
|
159 |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
|
160 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
161 |
-
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] #
|
162 |
-
gr.Warning(f"
|
163 |
-
input_ids = input_ids.to(device)
|
164 |
|
165 |
-
#
|
166 |
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
167 |
generate_kwargs = {
|
168 |
"input_ids": input_ids,
|
@@ -175,10 +177,10 @@ def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_
|
|
175 |
"num_beams": 1,
|
176 |
"repetition_penalty": repetition_penalty,
|
177 |
}
|
178 |
-
t = Thread(target=model.generate, kwargs=generate_kwargs) #
|
179 |
t.start()
|
180 |
|
181 |
-
# Stream
|
182 |
outputs = []
|
183 |
for text in streamer:
|
184 |
outputs.append(text)
|
@@ -187,9 +189,9 @@ def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_
|
|
187 |
@lru_cache(maxsize=128)
|
188 |
def process_query(query: str) -> Dict[str, Any]:
|
189 |
"""
|
190 |
-
|
191 |
"""
|
192 |
-
#
|
193 |
web_search_keywords = ["search", "find", "lookup", "google"]
|
194 |
general_query_keywords = ["explain", "describe", "tell me about", "what is", "how to"]
|
195 |
summarize_keywords = ["summarize", "summarise", "brief", "short"]
|
@@ -224,60 +226,60 @@ def process_query(query: str) -> Dict[str, Any]:
|
|
224 |
|
225 |
def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
|
226 |
"""
|
227 |
-
|
228 |
"""
|
229 |
function_name = function_call["name"]
|
230 |
arguments = function_call["arguments"]
|
231 |
|
232 |
if function_name == "web_search":
|
233 |
query = arguments["query"]
|
234 |
-
yield "🔍
|
235 |
web_results = search(query)
|
236 |
if not web_results:
|
237 |
-
yield "⚠️
|
238 |
return
|
239 |
-
#
|
240 |
-
web_summary = '\n\n'.join([f"🔗 **
|
241 |
if not web_summary:
|
242 |
-
web_summary = "⚠️
|
243 |
|
244 |
-
#
|
245 |
-
yield "📄 **
|
246 |
|
247 |
elif function_name == "summarize_query":
|
248 |
-
#
|
249 |
query = arguments["prompt"]
|
250 |
-
yield "🔍
|
251 |
web_results = search(query)
|
252 |
if not web_results:
|
253 |
-
yield "⚠️
|
254 |
return
|
255 |
-
#
|
256 |
-
combined_text = ' '.join([res['text'] for res in web_results if res['text'] != "
|
257 |
if not combined_text:
|
258 |
-
yield "⚠️
|
259 |
return
|
260 |
-
#
|
261 |
-
yield "📝
|
262 |
summary = summarize_text(combined_text)
|
263 |
-
yield "📄 **
|
264 |
|
265 |
elif function_name == "sentiment_analysis":
|
266 |
prompt_text = arguments["prompt"]
|
267 |
-
yield "📊
|
268 |
sentiment = analyze_sentiment(prompt_text)
|
269 |
yield sentiment
|
270 |
|
271 |
elif function_name == "train_model":
|
272 |
prompt_text = arguments["prompt"]
|
273 |
-
yield "📊
|
274 |
training_result = run_training()
|
275 |
yield training_result
|
276 |
|
277 |
elif function_name in ["general_query", "hard_query"]:
|
278 |
prompt_text = arguments["prompt"]
|
279 |
-
yield "🤖
|
280 |
-
#
|
281 |
response_generator = generate_response(
|
282 |
prompt=prompt_text,
|
283 |
chat_history=chat_history,
|
@@ -291,24 +293,24 @@ def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: L
|
|
291 |
yield response
|
292 |
|
293 |
else:
|
294 |
-
yield "⚠️
|
295 |
|
296 |
-
# ----------------------------
|
297 |
|
298 |
-
#
|
299 |
CHECKPOINT_DIR = "./checkpoints"
|
300 |
if not os.path.exists(CHECKPOINT_DIR):
|
301 |
os.makedirs(CHECKPOINT_DIR)
|
302 |
|
303 |
-
#
|
304 |
dataset = load_dataset('vntc/wiki-mini-corpus')
|
305 |
|
306 |
-
#
|
307 |
split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
|
308 |
train_dataset = split_dataset['train']
|
309 |
validation_dataset = split_dataset['test']
|
310 |
|
311 |
-
#
|
312 |
def preprocess_function(examples):
|
313 |
passages = [passage.lower().strip() for passage in examples['passage']]
|
314 |
return {'passage': passages}
|
@@ -316,7 +318,7 @@ def preprocess_function(examples):
|
|
316 |
processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
|
317 |
processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
|
318 |
|
319 |
-
#
|
320 |
if tokenizer.pad_token is None:
|
321 |
tokenizer.pad_token = tokenizer.eos_token
|
322 |
|
@@ -332,7 +334,7 @@ def tokenize_function(examples):
|
|
332 |
tokenized_train = processed_train.map(tokenize_function, batched=True)
|
333 |
tokenized_validation = processed_validation.map(tokenize_function, batched=True)
|
334 |
|
335 |
-
#
|
336 |
def add_labels(examples):
|
337 |
examples['labels'] = examples['input_ids'].copy()
|
338 |
return examples
|
@@ -340,31 +342,29 @@ def add_labels(examples):
|
|
340 |
tokenized_train = tokenized_train.map(add_labels, batched=True)
|
341 |
tokenized_validation = tokenized_validation.map(add_labels, batched=True)
|
342 |
|
343 |
-
#
|
344 |
tokenized_train = tokenized_train.remove_columns(['passage'])
|
345 |
tokenized_validation = tokenized_validation.remove_columns(['passage'])
|
346 |
|
347 |
-
#
|
348 |
tokenized_train.set_format('torch')
|
349 |
tokenized_validation.set_format('torch')
|
350 |
|
351 |
-
#
|
352 |
final_dataset = {
|
353 |
'train': tokenized_train,
|
354 |
'validation': tokenized_validation
|
355 |
}
|
356 |
|
357 |
-
#
|
358 |
class SaveCheckpointCallback(TrainerCallback):
|
359 |
-
def
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
# Load pretrained model with LoRA
|
368 |
pretrained = AutoModelForCausalLM.from_pretrained(
|
369 |
model_id,
|
370 |
device_map="auto",
|
@@ -374,30 +374,30 @@ pretrained = AutoModelForCausalLM.from_pretrained(
|
|
374 |
|
375 |
data_collator = DataCollatorForLanguageModeling(
|
376 |
tokenizer=tokenizer,
|
377 |
-
mlm=False, # Causal LM
|
378 |
pad_to_multiple_of=8
|
379 |
)
|
380 |
|
381 |
def get_step_done() -> int:
|
382 |
"""
|
383 |
-
|
384 |
|
385 |
Returns:
|
386 |
-
int:
|
387 |
"""
|
388 |
checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint-')]
|
389 |
if not checkpoints:
|
390 |
return 0
|
391 |
try:
|
392 |
-
#
|
393 |
latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('-')[1]))
|
394 |
step_done = int(latest_checkpoint.split('-')[1])
|
395 |
return step_done
|
396 |
except (IndexError, ValueError) as e:
|
397 |
-
print(f"
|
398 |
return 0
|
399 |
|
400 |
-
#
|
401 |
lora_config = LoraConfig(
|
402 |
r=8,
|
403 |
lora_alpha=32,
|
@@ -412,34 +412,34 @@ print(pretrained_model)
|
|
412 |
@spaces.GPU(duration=30, queue=False)
|
413 |
def run_training() -> str:
|
414 |
"""
|
415 |
-
|
416 |
|
417 |
Returns:
|
418 |
-
str:
|
419 |
"""
|
420 |
|
421 |
-
# TrainingArguments
|
422 |
training_args = TrainingArguments(
|
423 |
output_dir=CHECKPOINT_DIR,
|
424 |
per_device_train_batch_size=4,
|
425 |
per_device_eval_batch_size=4,
|
426 |
gradient_accumulation_steps=8,
|
427 |
num_train_epochs=3,
|
428 |
-
max_steps=300, #
|
429 |
learning_rate=3e-4,
|
430 |
weight_decay=0.01,
|
431 |
-
logging_steps=1, #
|
432 |
-
eval_strategy="steps", #
|
433 |
-
eval_steps=5, #
|
434 |
-
save_strategy="steps", #
|
435 |
-
save_steps=5, #
|
436 |
-
save_total_limit=5, #
|
437 |
-
fp16=True,
|
438 |
report_to="none",
|
439 |
load_best_model_at_end=True,
|
440 |
)
|
441 |
|
442 |
-
#
|
443 |
trainer = Trainer(
|
444 |
model=pretrained_model,
|
445 |
args=training_args,
|
@@ -447,50 +447,50 @@ def run_training() -> str:
|
|
447 |
eval_dataset=final_dataset['validation'],
|
448 |
tokenizer=tokenizer,
|
449 |
data_collator=data_collator,
|
450 |
-
callbacks=[SaveCheckpointCallback()], #
|
451 |
)
|
452 |
|
453 |
-
#
|
454 |
steps_done = get_step_done()
|
455 |
if steps_done > 0:
|
456 |
-
#
|
457 |
latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
|
458 |
if os.path.exists(latest_checkpoint):
|
459 |
-
print(f"
|
460 |
trainer.train(resume_from_checkpoint=latest_checkpoint)
|
461 |
else:
|
462 |
-
print(f"Checkpoint {latest_checkpoint}
|
463 |
trainer.train()
|
464 |
else:
|
465 |
trainer.train()
|
466 |
|
467 |
-
#
|
468 |
trainer.save_model(CHECKPOINT_DIR)
|
469 |
-
return "
|
470 |
|
471 |
-
#
|
472 |
@spaces.GPU(duration=30, queue=False)
|
473 |
def continuous_training(total_steps=300, steps_per_call=50):
|
474 |
"""
|
475 |
-
|
476 |
|
477 |
Args:
|
478 |
-
total_steps (int):
|
479 |
-
steps_per_call (int):
|
480 |
"""
|
481 |
steps_done = get_step_done()
|
482 |
while steps_done < total_steps:
|
483 |
remaining_steps = total_steps - steps_done
|
484 |
current_steps = min(steps_per_call, remaining_steps)
|
485 |
-
print(f"
|
486 |
|
487 |
-
#
|
488 |
training_args = TrainingArguments(
|
489 |
output_dir=CHECKPOINT_DIR,
|
490 |
per_device_train_batch_size=4,
|
491 |
per_device_eval_batch_size=4,
|
492 |
gradient_accumulation_steps=8,
|
493 |
-
num_train_epochs=1, #
|
494 |
max_steps=current_steps,
|
495 |
learning_rate=3e-4,
|
496 |
weight_decay=0.01,
|
@@ -505,7 +505,7 @@ def continuous_training(total_steps=300, steps_per_call=50):
|
|
505 |
load_best_model_at_end=True,
|
506 |
)
|
507 |
|
508 |
-
#
|
509 |
trainer = Trainer(
|
510 |
model=pretrained_model,
|
511 |
args=training_args,
|
@@ -516,30 +516,30 @@ def continuous_training(total_steps=300, steps_per_call=50):
|
|
516 |
callbacks=[SaveCheckpointCallback()],
|
517 |
)
|
518 |
|
519 |
-
#
|
520 |
if steps_done > 0:
|
521 |
latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
|
522 |
if os.path.exists(latest_checkpoint):
|
523 |
-
print(f"
|
524 |
trainer.train(resume_from_checkpoint=latest_checkpoint)
|
525 |
else:
|
526 |
-
print(f"Checkpoint {latest_checkpoint}
|
527 |
trainer.train()
|
528 |
else:
|
529 |
trainer.train()
|
530 |
|
531 |
steps_done = get_step_done()
|
532 |
-
print(f"
|
533 |
|
534 |
-
#
|
535 |
if steps_done >= total_steps:
|
536 |
-
print("
|
537 |
break
|
538 |
|
539 |
-
#
|
540 |
-
time.sleep(2) #
|
541 |
|
542 |
-
# ---------------------------- Gradio
|
543 |
|
544 |
@spaces.GPU(duration=30, queue=False)
|
545 |
def generate(
|
@@ -552,29 +552,29 @@ def generate(
|
|
552 |
repetition_penalty: float = 1.2,
|
553 |
) -> Iterator[str]:
|
554 |
"""
|
555 |
-
|
556 |
"""
|
557 |
-
#
|
558 |
-
yield "🔍
|
559 |
|
560 |
-
#
|
561 |
function_call = process_query(message)
|
562 |
|
563 |
-
#
|
564 |
if function_call["name"] == "web_search":
|
565 |
-
yield "🛠️
|
566 |
elif function_call["name"] == "summarize_query":
|
567 |
-
yield "🛠️
|
568 |
elif function_call["name"] == "sentiment_analysis":
|
569 |
-
yield "🛠️
|
570 |
elif function_call["name"] in ["general_query", "hard_query"]:
|
571 |
-
yield "🛠️
|
572 |
elif function_call["name"] == "train_model":
|
573 |
-
yield "🛠️
|
574 |
else:
|
575 |
-
yield "⚠️
|
576 |
|
577 |
-
#
|
578 |
response_iterator = handle_functions(
|
579 |
function_call=function_call,
|
580 |
prompt=message,
|
@@ -589,40 +589,40 @@ def generate(
|
|
589 |
for response in response_iterator:
|
590 |
yield response
|
591 |
|
592 |
-
#
|
593 |
EXAMPLES = [
|
594 |
-
["
|
595 |
-
["
|
596 |
-
["
|
597 |
-
["
|
598 |
-
["
|
599 |
-
["
|
600 |
-
["
|
601 |
-
["
|
602 |
-
["
|
603 |
-
["
|
604 |
]
|
605 |
|
606 |
-
#
|
607 |
chat_interface = gr.ChatInterface(
|
608 |
-
fn=generate, #
|
609 |
additional_inputs=[
|
610 |
gr.Slider(
|
611 |
-
label="
|
612 |
minimum=1,
|
613 |
maximum=MAX_MAX_NEW_TOKENS,
|
614 |
step=1,
|
615 |
value=DEFAULT_MAX_NEW_TOKENS,
|
616 |
),
|
617 |
gr.Slider(
|
618 |
-
label="
|
619 |
minimum=0.1,
|
620 |
maximum=4.0,
|
621 |
step=0.1,
|
622 |
value=0.6,
|
623 |
),
|
624 |
gr.Slider(
|
625 |
-
label="Top-p (
|
626 |
minimum=0.05,
|
627 |
maximum=1.0,
|
628 |
step=0.05,
|
@@ -636,45 +636,45 @@ chat_interface = gr.ChatInterface(
|
|
636 |
value=50,
|
637 |
),
|
638 |
gr.Slider(
|
639 |
-
label="
|
640 |
minimum=1.0,
|
641 |
maximum=2.0,
|
642 |
step=0.05,
|
643 |
value=1.2,
|
644 |
),
|
645 |
],
|
646 |
-
stop_btn=None, #
|
647 |
-
examples=EXAMPLES, #
|
648 |
-
cache_examples=False, #
|
649 |
title="🤖 OpenGPT-4o Chatbot",
|
650 |
-
description="
|
651 |
-
theme="default", #
|
652 |
)
|
653 |
|
654 |
-
#
|
655 |
with gr.Blocks(css="""
|
656 |
.gradio-container {
|
657 |
-
background-color: #f0f2f5; /*
|
658 |
}
|
659 |
.gradio-container h1 {
|
660 |
-
color: #4a90e2; /*
|
661 |
}
|
662 |
.gradio-container .gr-button {
|
663 |
-
background-color: #4a90e2; /*
|
664 |
-
color: white; /*
|
665 |
}
|
666 |
.gradio-container .gr-slider__label {
|
667 |
-
color: #333333; /*
|
668 |
}
|
669 |
.gradio-container .gr-chatbot {
|
670 |
-
border: 2px solid #4a90e2; /*
|
671 |
-
border-radius: 10px; /*
|
672 |
-
padding: 10px; /*
|
673 |
-
background-color: #ffffff; /*
|
674 |
}
|
675 |
""", fill_height=True) as demo:
|
676 |
-
gr.Markdown(DESCRIPTION) #
|
677 |
-
chat_interface.render() #
|
678 |
|
679 |
if __name__ == "__main__":
|
680 |
-
demo.queue(max_size=30).launch() #
|
|
|
23 |
from peft import LoraConfig, get_peft_model
|
24 |
import time
|
25 |
|
26 |
+
# ---------------------------- Cấu Hình ---------------------------- #
|
27 |
|
28 |
DESCRIPTION = """\
|
29 |
+
# Llama 3.2 3B Instruct với Chức Năng Nâng Cao
|
30 |
|
31 |
+
Llama 3.2 3B là phiên bản mới nhất của Meta về các mô hình ngôn ngữ mở.
|
32 |
+
Demo này giới thiệu [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), được tinh chỉnh để theo dõi hướng dẫn.
|
33 |
+
Để biết thêm chi tiết, vui lòng xem [bài viết của chúng tôi](https://huggingface.co/blog/llama32).
|
34 |
"""
|
35 |
|
36 |
+
MAX_MAX_NEW_TOKENS = 2048 # Số token tối đa có thể tạo ra
|
37 |
+
DEFAULT_MAX_NEW_TOKENS = 1024 # Số token tạo ra mặc định
|
38 |
+
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000")) # Độ dài token tối đa cho đầu vào
|
39 |
|
40 |
+
# Xác định thiết bị sử dụng (GPU nếu có, ngược lại CPU)
|
41 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
42 |
|
43 |
+
model_id = "meta-llama/Llama-3.2-3B-Instruct" # ID mô hình
|
44 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
45 |
model = AutoModelForCausalLM.from_pretrained(
|
46 |
model_id,
|
47 |
device_map="auto",
|
48 |
+
torch_dtype=torch.float16, # Sử dụng float16 để tương thích với fp16=True
|
49 |
)
|
50 |
model.to(device)
|
51 |
model.eval()
|
52 |
|
53 |
+
# Khởi tạo pipeline phân tích tâm lý trên GPU nếu có
|
54 |
sentiment_pipeline = pipeline(
|
55 |
"sentiment-analysis",
|
56 |
model="nlptown/bert-base-multilingual-uncased-sentiment",
|
57 |
device=0 if torch.cuda.is_available() else -1
|
58 |
)
|
59 |
|
60 |
+
# ---------------------------- Định Nghĩa Hàm ---------------------------- #
|
61 |
|
62 |
@lru_cache(maxsize=128)
|
63 |
def extract_text_from_webpage(html_content: str) -> str:
|
64 |
+
"""Trích xuất văn bản hiển thị từ nội dung HTML sử dụng BeautifulSoup."""
|
65 |
soup = BeautifulSoup(html_content, "html.parser")
|
66 |
+
# Loại bỏ các thẻ không hiển thị như script, style, header, footer, nav, form, svg
|
67 |
for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
|
68 |
tag.extract()
|
69 |
+
# Trích xuất văn bản hiển thị, tách bằng dấu cách và loại bỏ khoảng trắng thừa
|
70 |
visible_text = soup.get_text(separator=' ', strip=True)
|
71 |
return visible_text
|
72 |
|
73 |
def search(query: str) -> List[Dict[str, Any]]:
|
74 |
+
"""Thực hiện tìm kiếm trên Google và trả về kết quả."""
|
75 |
term = query
|
76 |
all_results = []
|
77 |
+
max_chars_per_page = 8000 # Số ký tự tối đa mỗi trang
|
78 |
headers = {
|
79 |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
|
80 |
}
|
|
|
83 |
resp = session.get(
|
84 |
url="https://www.google.com/search",
|
85 |
headers=headers,
|
86 |
+
params={"q": term, "num": 4}, # Tìm kiếm với 4 kết quả mỗi trang
|
87 |
timeout=5,
|
88 |
+
verify=False, # Bỏ qua xác minh SSL
|
89 |
)
|
90 |
resp.raise_for_status()
|
91 |
soup = BeautifulSoup(resp.text, "html.parser")
|
92 |
+
result_blocks = soup.find_all("div", attrs={"class": "g"}) # Tìm tất cả các khối kết quả
|
93 |
for result in result_blocks:
|
94 |
+
link_tag = result.find("a", href=True) # Tìm thẻ liên kết
|
95 |
if link_tag and 'href' in link_tag.attrs:
|
96 |
link = link_tag["href"]
|
97 |
try:
|
|
|
104 |
webpage.raise_for_status()
|
105 |
visible_text = extract_text_from_webpage(webpage.text)
|
106 |
if len(visible_text) > max_chars_per_page:
|
107 |
+
visible_text = visible_text[:max_chars_per_page] # Cắt văn bản nếu quá dài
|
108 |
all_results.append({"link": link, "text": visible_text})
|
109 |
except requests.exceptions.RequestException:
|
110 |
+
all_results.append({"link": link, "text": "Không thể lấy nội dung."})
|
111 |
except requests.exceptions.RequestException as e:
|
112 |
+
all_results.append({"link": "N/A", "text": "Không thể thực hiện tìm kiếm."})
|
113 |
return all_results
|
114 |
|
115 |
def summarize_text(text: str, max_length: int = 150) -> str:
|
116 |
+
"""Tóm tắt văn bản sử dụng mô hình Llama."""
|
117 |
conversation = [
|
118 |
+
{"role": "user", "content": f"Hãy tóm tắt đoạn văn sau: {text}"}
|
119 |
]
|
120 |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
|
121 |
input_ids = input_ids.to(device)
|
|
|
138 |
return summary
|
139 |
|
140 |
def analyze_sentiment(text: str) -> str:
|
141 |
+
"""Phân tích tâm lý của văn bản sử dụng mô hình."""
|
142 |
result = sentiment_pipeline(text)
|
143 |
sentiment = result[0]['label']
|
144 |
score = result[0]['score']
|
145 |
+
return f"🟢 **Tâm lý**: {sentiment} (Điểm: {score:.2f})"
|
146 |
|
147 |
def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
|
148 |
"""
|
149 |
+
Tạo phản hồi sử dụng mô hình Llama cục bộ theo chế độ streaming.
|
150 |
"""
|
151 |
+
# Xây dựng lịch sử cuộc trò chuyện
|
152 |
conversation = []
|
153 |
for user, assistant in chat_history:
|
154 |
conversation.extend([
|
155 |
{"role": "user", "content": user},
|
156 |
{"role": "assistant", "content": assistant},
|
157 |
])
|
158 |
+
conversation.append({"role": "user", "content": prompt}) # Thêm tin nhắn của người dùng
|
159 |
|
160 |
+
# Chuẩn bị input_ids từ tokenizer
|
161 |
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
|
162 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
163 |
+
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] # Cắt input nếu quá dài
|
164 |
+
gr.Warning(f"Đã cắt bỏ phần cuộc trò chuyện vì vượt quá {MAX_INPUT_TOKEN_LENGTH} token.")
|
165 |
+
input_ids = input_ids.to(device) # Di chuyển input tới thiết bị
|
166 |
|
167 |
+
# Khởi tạo streamer để nhận văn bản được tạo ra theo thời gian thực
|
168 |
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
169 |
generate_kwargs = {
|
170 |
"input_ids": input_ids,
|
|
|
177 |
"num_beams": 1,
|
178 |
"repetition_penalty": repetition_penalty,
|
179 |
}
|
180 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs) # Tạo luồng để sinh văn bản
|
181 |
t.start()
|
182 |
|
183 |
+
# Stream văn bản được tạo ra
|
184 |
outputs = []
|
185 |
for text in streamer:
|
186 |
outputs.append(text)
|
|
|
189 |
@lru_cache(maxsize=128)
|
190 |
def process_query(query: str) -> Dict[str, Any]:
|
191 |
"""
|
192 |
+
Xác định hàm nào sẽ được gọi dựa trên truy vấn của người dùng.
|
193 |
"""
|
194 |
+
# Định nghĩa các từ khóa hoặc mẫu để xác định hàm
|
195 |
web_search_keywords = ["search", "find", "lookup", "google"]
|
196 |
general_query_keywords = ["explain", "describe", "tell me about", "what is", "how to"]
|
197 |
summarize_keywords = ["summarize", "summarise", "brief", "short"]
|
|
|
226 |
|
227 |
def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
|
228 |
"""
|
229 |
+
Thực thi hàm phù hợp dựa trên lời gọi hàm.
|
230 |
"""
|
231 |
function_name = function_call["name"]
|
232 |
arguments = function_call["arguments"]
|
233 |
|
234 |
if function_name == "web_search":
|
235 |
query = arguments["query"]
|
236 |
+
yield "🔍 Đang thực hiện tìm kiếm trên web..."
|
237 |
web_results = search(query)
|
238 |
if not web_results:
|
239 |
+
yield "⚠️ Không tìm thấy kết quả."
|
240 |
return
|
241 |
+
# Tóm tắt kết quả tìm kiếm
|
242 |
+
web_summary = '\n\n'.join([f"🔗 **Liên kết**: {res['link']}\n📝 **Mô tả**: {res['text']}" for res in web_results if res["text"] != "Không thể lấy nội dung."])
|
243 |
if not web_summary:
|
244 |
+
web_summary = "⚠️ Không thể lấy nội dung từ kết quả tìm kiếm."
|
245 |
|
246 |
+
# Trả về kết quả tìm kiếm cho người dùng
|
247 |
+
yield "📄 **Kết quả tìm kiếm:**\n" + web_summary
|
248 |
|
249 |
elif function_name == "summarize_query":
|
250 |
+
# Khi người dùng yêu cầu tóm tắt, hệ thống sẽ thực hiện tìm kiếm và sau đó tóm tắt kết quả
|
251 |
query = arguments["prompt"]
|
252 |
+
yield "🔍 Đang thực hiện tìm kiếm để tóm tắt..."
|
253 |
web_results = search(query)
|
254 |
if not web_results:
|
255 |
+
yield "⚠️ Không tìm thấy kết quả để tóm tắt."
|
256 |
return
|
257 |
+
# Lấy nội dung từ kết quả tìm kiếm để tóm tắt
|
258 |
+
combined_text = ' '.join([res['text'] for res in web_results if res['text'] != "Không thể lấy nội dung."])
|
259 |
if not combined_text:
|
260 |
+
yield "⚠️ Không có nội dung để tóm tắt."
|
261 |
return
|
262 |
+
# Tóm tắt nội dung đã lấy
|
263 |
+
yield "📝 Đang tóm tắt thông tin..."
|
264 |
summary = summarize_text(combined_text)
|
265 |
+
yield "📄 **Tóm tắt:**\n" + summary
|
266 |
|
267 |
elif function_name == "sentiment_analysis":
|
268 |
prompt_text = arguments["prompt"]
|
269 |
+
yield "📊 Đang phân tích tâm lý..."
|
270 |
sentiment = analyze_sentiment(prompt_text)
|
271 |
yield sentiment
|
272 |
|
273 |
elif function_name == "train_model":
|
274 |
prompt_text = arguments["prompt"]
|
275 |
+
yield "📊 Đang huấn luyện mô hình..."
|
276 |
training_result = run_training()
|
277 |
yield training_result
|
278 |
|
279 |
elif function_name in ["general_query", "hard_query"]:
|
280 |
prompt_text = arguments["prompt"]
|
281 |
+
yield "🤖 Đang tạo phản hồi..."
|
282 |
+
# Tạo phản hồi sử dụng mô hình Llama
|
283 |
response_generator = generate_response(
|
284 |
prompt=prompt_text,
|
285 |
chat_history=chat_history,
|
|
|
293 |
yield response
|
294 |
|
295 |
else:
|
296 |
+
yield "⚠️ Lời gọi hàm không được nhận dạng."
|
297 |
|
298 |
+
# ---------------------------- Huấn luyện ---------------------------- #
|
299 |
|
300 |
+
# Đường dẫn lưu checkpoint
|
301 |
CHECKPOINT_DIR = "./checkpoints"
|
302 |
if not os.path.exists(CHECKPOINT_DIR):
|
303 |
os.makedirs(CHECKPOINT_DIR)
|
304 |
|
305 |
+
# Tải Dataset (CPU)
|
306 |
dataset = load_dataset('vntc/wiki-mini-corpus')
|
307 |
|
308 |
+
# Chia Dataset thành train và validation (CPU)
|
309 |
split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
|
310 |
train_dataset = split_dataset['train']
|
311 |
validation_dataset = split_dataset['test']
|
312 |
|
313 |
+
# Tiền Xử Lý Văn Bản (CPU)
|
314 |
def preprocess_function(examples):
|
315 |
passages = [passage.lower().strip() for passage in examples['passage']]
|
316 |
return {'passage': passages}
|
|
|
318 |
processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
|
319 |
processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
|
320 |
|
321 |
+
# Đảm bảo tokenizer có pad_token
|
322 |
if tokenizer.pad_token is None:
|
323 |
tokenizer.pad_token = tokenizer.eos_token
|
324 |
|
|
|
334 |
tokenized_train = processed_train.map(tokenize_function, batched=True)
|
335 |
tokenized_validation = processed_validation.map(tokenize_function, batched=True)
|
336 |
|
337 |
+
# Thêm trường 'labels' (CPU)
|
338 |
def add_labels(examples):
|
339 |
examples['labels'] = examples['input_ids'].copy()
|
340 |
return examples
|
|
|
342 |
tokenized_train = tokenized_train.map(add_labels, batched=True)
|
343 |
tokenized_validation = tokenized_validation.map(add_labels, batched=True)
|
344 |
|
345 |
+
# Loại bỏ các cột không cần thiết (CPU)
|
346 |
tokenized_train = tokenized_train.remove_columns(['passage'])
|
347 |
tokenized_validation = tokenized_validation.remove_columns(['passage'])
|
348 |
|
349 |
+
# Định dạng dữ liệu cho PyTorch (CPU)
|
350 |
tokenized_train.set_format('torch')
|
351 |
tokenized_validation.set_format('torch')
|
352 |
|
353 |
+
# Tạo DatasetDict (CPU)
|
354 |
final_dataset = {
|
355 |
'train': tokenized_train,
|
356 |
'validation': tokenized_validation
|
357 |
}
|
358 |
|
359 |
+
# Định Nghĩa TrainerCallback để Lưu Checkpoint Nhanh Hơn
|
360 |
class SaveCheckpointCallback(TrainerCallback):
|
361 |
+
def on_save(self, args, state, control, **kwargs):
|
362 |
+
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
|
363 |
+
print(f"Lưu checkpoint tại: {checkpoint_path}")
|
364 |
+
kwargs['trainer'].save_model(checkpoint_path)
|
365 |
+
return control # Trả về đối tượng control hiện tại
|
366 |
+
|
367 |
+
# Tải mô hình đã được pretrained
|
|
|
|
|
368 |
pretrained = AutoModelForCausalLM.from_pretrained(
|
369 |
model_id,
|
370 |
device_map="auto",
|
|
|
374 |
|
375 |
data_collator = DataCollatorForLanguageModeling(
|
376 |
tokenizer=tokenizer,
|
377 |
+
mlm=False, # Vì bạn đang thực hiện Causal LM
|
378 |
pad_to_multiple_of=8
|
379 |
)
|
380 |
|
381 |
def get_step_done() -> int:
|
382 |
"""
|
383 |
+
Lấy số bước huấn luyện đã hoàn thành từ checkpoint mới nhất trong thư mục lưu trữ.
|
384 |
|
385 |
Returns:
|
386 |
+
int: Số bước đã hoàn thành. Trả về 0 nếu không tìm thấy checkpoint.
|
387 |
"""
|
388 |
checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint-')]
|
389 |
if not checkpoints:
|
390 |
return 0
|
391 |
try:
|
392 |
+
# Tìm checkpoint mới nhất dựa trên số bước
|
393 |
latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('-')[1]))
|
394 |
step_done = int(latest_checkpoint.split('-')[1])
|
395 |
return step_done
|
396 |
except (IndexError, ValueError) as e:
|
397 |
+
print(f"Lỗi khi phân tích tên checkpoint: {e}")
|
398 |
return 0
|
399 |
|
400 |
+
# Tải và Cấu Hình Mô Hình với LoRA (GPU)
|
401 |
lora_config = LoraConfig(
|
402 |
r=8,
|
403 |
lora_alpha=32,
|
|
|
412 |
@spaces.GPU(duration=30, queue=False)
|
413 |
def run_training() -> str:
|
414 |
"""
|
415 |
+
Hàm huấn luyện mô hình sử dụng GPU với thời gian hạn chế.
|
416 |
|
417 |
Returns:
|
418 |
+
str: Thông báo kết quả huấn luyện.
|
419 |
"""
|
420 |
|
421 |
+
# Cấu Hình TrainingArguments (GPU)
|
422 |
training_args = TrainingArguments(
|
423 |
output_dir=CHECKPOINT_DIR,
|
424 |
per_device_train_batch_size=4,
|
425 |
per_device_eval_batch_size=4,
|
426 |
gradient_accumulation_steps=8,
|
427 |
num_train_epochs=3,
|
428 |
+
max_steps=300, # Đặt tổng số bước huấn luyện
|
429 |
learning_rate=3e-4,
|
430 |
weight_decay=0.01,
|
431 |
+
logging_steps=1, # Ghi log sau mỗi bước
|
432 |
+
eval_strategy="steps", # Đánh giá sau mỗi vài bước
|
433 |
+
eval_steps=5, # Đánh giá sau mỗi 5 bước
|
434 |
+
save_strategy="steps", # Lưu checkpoint sau mỗi vài bước
|
435 |
+
save_steps=5, # Lưu checkpoint sau mỗi 5 bước
|
436 |
+
save_total_limit=5, # Giới hạn số lượng checkpoint lưu trữ
|
437 |
+
fp16=True, # Kích hoạt huấn luyện hỗn hợp độ chính xác
|
438 |
report_to="none",
|
439 |
load_best_model_at_end=True,
|
440 |
)
|
441 |
|
442 |
+
# Tạo Trainer (GPU)
|
443 |
trainer = Trainer(
|
444 |
model=pretrained_model,
|
445 |
args=training_args,
|
|
|
447 |
eval_dataset=final_dataset['validation'],
|
448 |
tokenizer=tokenizer,
|
449 |
data_collator=data_collator,
|
450 |
+
callbacks=[SaveCheckpointCallback()], # Thêm callback
|
451 |
)
|
452 |
|
453 |
+
# Kiểm tra nếu có checkpoint
|
454 |
steps_done = get_step_done()
|
455 |
if steps_done > 0:
|
456 |
+
# Xác định checkpoint mới nhất dựa trên số bước
|
457 |
latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
|
458 |
if os.path.exists(latest_checkpoint):
|
459 |
+
print(f"Đang tiếp tục huấn luyện từ checkpoint: {latest_checkpoint}")
|
460 |
trainer.train(resume_from_checkpoint=latest_checkpoint)
|
461 |
else:
|
462 |
+
print(f"Checkpoint {latest_checkpoint} không tồn tại. Bắt đầu huấn luyện từ đầu.")
|
463 |
trainer.train()
|
464 |
else:
|
465 |
trainer.train()
|
466 |
|
467 |
+
# Lưu checkpoint sau khi huấn luyện
|
468 |
trainer.save_model(CHECKPOINT_DIR)
|
469 |
+
return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
|
470 |
|
471 |
+
# Hàm Tự Động Hóa Việc Gọi Lặp Lại Hàm Huấn Luyện
|
472 |
@spaces.GPU(duration=30, queue=False)
|
473 |
def continuous_training(total_steps=300, steps_per_call=50):
|
474 |
"""
|
475 |
+
Hàm tự động gọi lại `run_training` để hoàn thành quá trình huấn luyện.
|
476 |
|
477 |
Args:
|
478 |
+
total_steps (int): Tổng số bước huấn luyện mong muốn.
|
479 |
+
steps_per_call (int): Số bước huấn luyện mỗi lần gọi hàm.
|
480 |
"""
|
481 |
steps_done = get_step_done()
|
482 |
while steps_done < total_steps:
|
483 |
remaining_steps = total_steps - steps_done
|
484 |
current_steps = min(steps_per_call, remaining_steps)
|
485 |
+
print(f"Bắt đầu huấn luyện cho {current_steps} bước.")
|
486 |
|
487 |
+
# Cập nhật TrainingArguments để huấn luyện cho current_steps bước
|
488 |
training_args = TrainingArguments(
|
489 |
output_dir=CHECKPOINT_DIR,
|
490 |
per_device_train_batch_size=4,
|
491 |
per_device_eval_batch_size=4,
|
492 |
gradient_accumulation_steps=8,
|
493 |
+
num_train_epochs=1, # Huấn luyện trong một epoch
|
494 |
max_steps=current_steps,
|
495 |
learning_rate=3e-4,
|
496 |
weight_decay=0.01,
|
|
|
505 |
load_best_model_at_end=True,
|
506 |
)
|
507 |
|
508 |
+
# Tạo Trainer với TrainingArguments mới
|
509 |
trainer = Trainer(
|
510 |
model=pretrained_model,
|
511 |
args=training_args,
|
|
|
516 |
callbacks=[SaveCheckpointCallback()],
|
517 |
)
|
518 |
|
519 |
+
# Tiếp tục huấn luyện từ checkpoint hiện tại
|
520 |
if steps_done > 0:
|
521 |
latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
|
522 |
if os.path.exists(latest_checkpoint):
|
523 |
+
print(f"Đang tiếp tục huấn luyện từ checkpoint: {latest_checkpoint}")
|
524 |
trainer.train(resume_from_checkpoint=latest_checkpoint)
|
525 |
else:
|
526 |
+
print(f"Checkpoint {latest_checkpoint} không tồn tại. Bắt đầu huấn luyện từ đầu.")
|
527 |
trainer.train()
|
528 |
else:
|
529 |
trainer.train()
|
530 |
|
531 |
steps_done = get_step_done()
|
532 |
+
print(f"Đã huấn luyện {steps_done} / {total_steps} bước.")
|
533 |
|
534 |
+
# Kiểm tra nếu đã đạt số bước mong muốn
|
535 |
if steps_done >= total_steps:
|
536 |
+
print("Đã hoàn thành toàn bộ quá trình huấn luyện.")
|
537 |
break
|
538 |
|
539 |
+
# Chờ một khoảng thời gian trước khi gọi lại (tùy thuộc vào yêu cầu của hệ thống)
|
540 |
+
time.sleep(2) # Thời gian chờ có thể điều chỉnh
|
541 |
|
542 |
+
# ---------------------------- Giao Diện Gradio ---------------------------- #
|
543 |
|
544 |
@spaces.GPU(duration=30, queue=False)
|
545 |
def generate(
|
|
|
552 |
repetition_penalty: float = 1.2,
|
553 |
) -> Iterator[str]:
|
554 |
"""
|
555 |
+
Hàm chính để xử lý đầu vào của người dùng và tạo phản hồi.
|
556 |
"""
|
557 |
+
# Thông báo về việc phân tích đầu vào
|
558 |
+
yield "🔍 Đang phân tích truy vấn của bạn..."
|
559 |
|
560 |
+
# Xác định hàm nào sẽ được gọi dựa trên tin nhắn của người dùng
|
561 |
function_call = process_query(message)
|
562 |
|
563 |
+
# Thông báo về hàm được chọn
|
564 |
if function_call["name"] == "web_search":
|
565 |
+
yield "🛠️ Đã chọn chức năng: Tìm kiếm trên web."
|
566 |
elif function_call["name"] == "summarize_query":
|
567 |
+
yield "🛠️ Đã chọn chức năng: Tóm tắt văn bản."
|
568 |
elif function_call["name"] == "sentiment_analysis":
|
569 |
+
yield "🛠️ Đã chọn chức năng: Phân tích tâm lý."
|
570 |
elif function_call["name"] in ["general_query", "hard_query"]:
|
571 |
+
yield "🛠️ Đã chọn chức năng: Trả lời câu hỏi."
|
572 |
elif function_call["name"] == "train_model":
|
573 |
+
yield "🛠️ Đã chọn chức năng: Huấn luyện mô hình."
|
574 |
else:
|
575 |
+
yield "⚠️ Không thể xác định chức năng phù hợp."
|
576 |
|
577 |
+
# Xử lý lời gọi hàm và sinh phản hồi tương ứng
|
578 |
response_iterator = handle_functions(
|
579 |
function_call=function_call,
|
580 |
prompt=message,
|
|
|
589 |
for response in response_iterator:
|
590 |
yield response
|
591 |
|
592 |
+
# Định nghĩa các ví dụ để hướng dẫn người dùng
|
593 |
EXAMPLES = [
|
594 |
+
["Xin chào! Bạn khỏe không?"],
|
595 |
+
["Bạn có thể giải thích ngắn gọn về ngôn ngữ lập trình Python không?"],
|
596 |
+
["Giải thích cốt truyện của Cô bé Lọ Lem trong một câu."],
|
597 |
+
["Một người đàn ông cần bao nhiêu giờ để ăn một chiếc máy bay trực thăng?"],
|
598 |
+
["Viết một bài báo 100 từ về 'Lợi ích của mã nguồn mở trong nghiên cứu AI'"],
|
599 |
+
["Tìm và cung cấp cho tôi tin tức mới nhất về năng lượng tái tạo."],
|
600 |
+
["Tìm thông tin về Rạn san hô Great Barrier Reef."],
|
601 |
+
["Tóm tắt nội dung về trí tuệ nhân tạo."],
|
602 |
+
["Phân tích tâm lý của đoạn văn sau: Tôi rất vui khi được gặp bạn hôm nay!"],
|
603 |
+
["Huấn luyện mô hình!"],
|
604 |
]
|
605 |
|
606 |
+
# Cấu hình giao diện trò chuyện của Gradio với giao diện đẹp mắt
|
607 |
chat_interface = gr.ChatInterface(
|
608 |
+
fn=generate, # Hàm được gọi khi có tương tác từ người dùng
|
609 |
additional_inputs=[
|
610 |
gr.Slider(
|
611 |
+
label="Số token mới tối đa",
|
612 |
minimum=1,
|
613 |
maximum=MAX_MAX_NEW_TOKENS,
|
614 |
step=1,
|
615 |
value=DEFAULT_MAX_NEW_TOKENS,
|
616 |
),
|
617 |
gr.Slider(
|
618 |
+
label="Nhiệt độ",
|
619 |
minimum=0.1,
|
620 |
maximum=4.0,
|
621 |
step=0.1,
|
622 |
value=0.6,
|
623 |
),
|
624 |
gr.Slider(
|
625 |
+
label="Top-p (nucleus sampling)",
|
626 |
minimum=0.05,
|
627 |
maximum=1.0,
|
628 |
step=0.05,
|
|
|
636 |
value=50,
|
637 |
),
|
638 |
gr.Slider(
|
639 |
+
label="Hình phạt sự lặp lại",
|
640 |
minimum=1.0,
|
641 |
maximum=2.0,
|
642 |
step=0.05,
|
643 |
value=1.2,
|
644 |
),
|
645 |
],
|
646 |
+
stop_btn=None, # Không có nút dừng
|
647 |
+
examples=EXAMPLES, # Các ví dụ được hiển thị cho người dùng
|
648 |
+
cache_examples=False, # Không lưu bộ nhớ cache cho các ví dụ
|
649 |
title="🤖 OpenGPT-4o Chatbot",
|
650 |
+
description="Một trợ lý AI mạnh mẽ sử dụng mô hình Llama-3.2 cục bộ với các chức năng tìm kiếm web, tóm tắt văn bản và phân tích tâm lý.",
|
651 |
+
theme="default", # Có thể thay đổi theme để giao diện đẹp hơn
|
652 |
)
|
653 |
|
654 |
+
# Tạo giao diện chính của Gradio với CSS tùy chỉnh
|
655 |
with gr.Blocks(css="""
|
656 |
.gradio-container {
|
657 |
+
background-color: #f0f2f5; /* Màu nền nhẹ nhàng */
|
658 |
}
|
659 |
.gradio-container h1 {
|
660 |
+
color: #4a90e2; /* Màu xanh dương cho tiêu đề */
|
661 |
}
|
662 |
.gradio-container .gr-button {
|
663 |
+
background-color: #4a90e2; /* Màu xanh dương cho nút */
|
664 |
+
color: white; /* Màu chữ trắng trên nút */
|
665 |
}
|
666 |
.gradio-container .gr-slider__label {
|
667 |
+
color: #333333; /* Màu chữ đen cho nhãn slider */
|
668 |
}
|
669 |
.gradio-container .gr-chatbot {
|
670 |
+
border: 2px solid #4a90e2; /* Viền xanh dương cho chatbot */
|
671 |
+
border-radius: 10px; /* Bo góc viền chatbot */
|
672 |
+
padding: 10px; /* Khoảng cách bên trong chatbot */
|
673 |
+
background-color: #ffffff; /* Màu nền trắng cho chatbot */
|
674 |
}
|
675 |
""", fill_height=True) as demo:
|
676 |
+
gr.Markdown(DESCRIPTION) # Hiển thị mô tả
|
677 |
+
chat_interface.render() # Hiển thị giao diện trò chuyện
|
678 |
|
679 |
if __name__ == "__main__":
|
680 |
+
demo.queue(max_size=30).launch() # Khởi chạy ứng dụng Gradio với hàng đợi kích thước tối đa là 30
|