Commit 0abfd6d · Parent(s): 1655dfc
fix errors

app.py CHANGED

[Removed lines: the previous revision of app.py lines 23-680, whose comments, docstrings, and user-facing messages were in Vietnamese, are truncated in this diff view. The updated file content follows.]

# ... (lines 1-22 unchanged, not shown)
from peft import LoraConfig, get_peft_model
import time

# ---------------------------- Configuration ---------------------------- #

DESCRIPTION = """\
# Llama 3.2 3B Instruct with Advanced Features

Llama 3.2 3B is the latest version of Meta's open large language models.
This demo showcases [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), fine-tuned for instruction-following.
For more details, please see [our blog post](https://huggingface.co/blog/llama32).
"""

MAX_MAX_NEW_TOKENS = 2048  # Maximum tokens to generate
DEFAULT_MAX_NEW_TOKENS = 1024  # Default tokens to generate
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000"))  # Max input token length

# Determine device (GPU if available, else CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "meta-llama/Llama-3.2-3B-Instruct"  # Model ID
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,  # Use float16 for compatibility with fp16=True
)
model.to(device)
model.eval()

# Initialize sentiment analysis pipeline on GPU if available
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model="nlptown/bert-base-multilingual-uncased-sentiment",
    device=0 if torch.cuda.is_available() else -1
)

# ---------------------------- Function Definitions ---------------------------- #

@lru_cache(maxsize=128)
def extract_text_from_webpage(html_content: str) -> str:
    """Extract visible text from HTML content using BeautifulSoup."""
    soup = BeautifulSoup(html_content, "html.parser")
    for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
        tag.extract()
    visible_text = soup.get_text(separator=' ', strip=True)
    return visible_text

def search(query: str) -> List[Dict[str, Any]]:
    """Perform a Google search and return results."""
    term = query
    all_results = []
    max_chars_per_page = 8000  # Max characters per page
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
    }
    # ... (lines 79-80 unchanged, not shown)
            resp = session.get(
                url="https://www.google.com/search",
                headers=headers,
                params={"q": term, "num": 4},  # 4 results per page
                timeout=5,
                verify=False,  # Skip SSL verification
            )
            resp.raise_for_status()
            soup = BeautifulSoup(resp.text, "html.parser")
            result_blocks = soup.find_all("div", attrs={"class": "g"})
            for result in result_blocks:
                link_tag = result.find("a", href=True)
                if link_tag and 'href' in link_tag.attrs:
                    link = link_tag["href"]
                    try:
                        # ... (lines 96-101 unchanged, not shown)
                        webpage.raise_for_status()
                        visible_text = extract_text_from_webpage(webpage.text)
                        if len(visible_text) > max_chars_per_page:
                            visible_text = visible_text[:max_chars_per_page]
                        all_results.append({"link": link, "text": visible_text})
                    except requests.exceptions.RequestException:
                        all_results.append({"link": link, "text": "Could not retrieve content."})
    except requests.exceptions.RequestException as e:
        all_results.append({"link": "N/A", "text": "Could not perform search."})
    return all_results

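For reference, each entry returned by search is a small dict, which is what the later filters on the placeholder text rely on. An illustrative sketch of the return shape (not part of this commit; the URL is made up):

results = search("latest news on renewable energy")
# expected shape:
# [{"link": "https://example.com/article", "text": "visible page text ..."},
#  {"link": "N/A", "text": "Could not perform search."}]  # appended if the search request itself fails
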
def summarize_text(text: str, max_length: int = 150) -> str:
    """Summarize text using the Llama model."""
    conversation = [
        {"role": "user", "content": f"Please summarize the following text: {text}"}
    ]
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    input_ids = input_ids.to(device)

    summary_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    summary_kwargs = {
        "input_ids": input_ids,
        # ... (lines 124-128 unchanged, not shown)
    }
    t = Thread(target=model.generate, kwargs=summary_kwargs)
    t.start()

    summary = ""
    for new_text in summary_streamer:
        summary += new_text
    return summary

def analyze_sentiment(text: str) -> str:
    """Analyze sentiment of the text using a sentiment analysis model."""
    result = sentiment_pipeline(text)
    sentiment = result[0]['label']
    score = result[0]['score']
    return f"🟢 **Sentiment**: {sentiment} (Score: {score:.2f})"

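analyze_sentiment relies on the pipeline returning a list with a single label/score dict; the nlptown model uses star-rating labels. A quick illustrative sketch (not part of this commit; values are made up):

print(sentiment_pipeline("I am very happy to meet you today!"))
# e.g. [{'label': '5 stars', 'score': 0.81}]  (illustrative values)
print(analyze_sentiment("I am very happy to meet you today!"))
# e.g. "🟢 **Sentiment**: 5 stars (Score: 0.81)"
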
def generate_response(prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
    """
    Generate a response using the Llama model in streaming mode.
    """
    # Build conversation history
    conversation = []
    for user, assistant in chat_history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": prompt})  # Add user's message

    # Prepare input_ids from tokenizer
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]  # Truncate input if too long
        gr.Warning(f"Truncated conversation due to exceeding {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(device)

    # Initialize streamer for real-time text generation
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        # ... (lines 169-174 unchanged, not shown)
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    t = Thread(target=model.generate, kwargs=generate_kwargs)  # Create thread for text generation
    t.start()

    # Stream generated text
    outputs = []
    for text in streamer:
        outputs.append(text)
        # ... (lines 185-186 unchanged, not shown)

@lru_cache(maxsize=128)
def process_query(query: str) -> Dict[str, Any]:
    """
    Determine which function to call based on the user's query.
    """
    # Define keywords or patterns to identify functions
    web_search_keywords = ["search", "find", "lookup", "google"]
    general_query_keywords = ["explain", "describe", "tell me about", "what is", "how to"]
    summarize_keywords = ["summarize", "summarise", "brief", "short"]
    sentiment_keywords = ["emotion", "mood", "sentiment", "analyze sentiment"]
    train_keywords = ["train"]

    query_lower = query.lower()

    if any(keyword in query_lower for keyword in web_search_keywords):
        function_name = "web_search"
        arguments = {"query": query}
    # ... (lines 204-215 unchanged, not shown)
    else:
        function_name = "hard_query"
        arguments = {"prompt": query}

    return {
        "name": function_name,
        "arguments": arguments
        # ... (line 223 unchanged, not shown)

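Because process_query routes purely on substring keyword matching against the lowercased message, its behaviour can be checked by hand. A minimal illustrative sketch (not part of this commit; it assumes the elided elif branches mirror the visible one):

process_query("Search for the latest news on renewable energy.")
# -> {"name": "web_search", "arguments": {"query": "Search for the latest news on renewable energy."}}

process_query("Hello! How are you?")  # matches none of the keyword lists
# -> {"name": "hard_query", "arguments": {"prompt": "Hello! How are you?"}}
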
def handle_functions(function_call: Dict[str, Any], prompt: str, chat_history: List[Tuple[str, str]], max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float) -> Iterator[str]:
    """
    Execute the appropriate function based on the function call.
    """
    function_name = function_call["name"]
    arguments = function_call["arguments"]

    if function_name == "web_search":
        query = arguments["query"]
        yield "🔍 Performing web search..."
        web_results = search(query)
        if not web_results:
            yield "⚠️ No results found."
            return
        # Summarize search results
        web_summary = '\n\n'.join([f"🔗 **Link**: {res['link']}\n📝 **Description**: {res['text']}" for res in web_results if res["text"] != "Could not retrieve content."])
        if not web_summary:
            web_summary = "⚠️ Could not retrieve content from search results."

        # Return search results to user
        yield "📄 **Search Results:**\n" + web_summary

    elif function_name == "summarize_query":
        # When user requests summarization, perform search and then summarize
        query = arguments["prompt"]
        yield "🔍 Performing search for summarization..."
        web_results = search(query)
        if not web_results:
            yield "⚠️ No results found to summarize."
            return
        # Combine content from search results for summarization
        combined_text = ' '.join([res['text'] for res in web_results if res['text'] != "Could not retrieve content."])
        if not combined_text:
            yield "⚠️ No content available to summarize."
            return
        # Summarize the combined content
        yield "📝 Summarizing information..."
        summary = summarize_text(combined_text)
        yield "📄 **Summary:**\n" + summary

    elif function_name == "sentiment_analysis":
        prompt_text = arguments["prompt"]
        yield "📊 Analyzing sentiment..."
        sentiment = analyze_sentiment(prompt_text)
        yield sentiment

    elif function_name == "train_model":
        prompt_text = arguments["prompt"]
        yield "📊 Training the model..."
        training_result = run_training()
        yield training_result

    elif function_name in ["general_query", "hard_query"]:
        prompt_text = arguments["prompt"]
        yield "🤖 Generating response..."
        # Generate response using the Llama model
        response_generator = generate_response(
            prompt=prompt_text,
            chat_history=chat_history,
            # ... (lines 284-288 unchanged, not shown)
        )
        for response in response_generator:
            yield response

    else:
        yield "⚠️ Unrecognized function call."

# ---------------------------- Training Setup ---------------------------- #

# Checkpoint directory
CHECKPOINT_DIR = "./checkpoints"
if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

# Load Dataset (CPU)
dataset = load_dataset('vntc/wiki-mini-corpus')

# Split Dataset into train and validation (CPU)
split_dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset['train']
validation_dataset = split_dataset['test']

# Text Preprocessing (CPU)
def preprocess_function(examples):
    passages = [passage.lower().strip() for passage in examples['passage']]
    return {'passage': passages}
# ... (line 315 unchanged, not shown)
processed_train = train_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])
processed_validation = validation_dataset.map(preprocess_function, batched=True, remove_columns=['id', 'metadata'])

# Ensure tokenizer has pad_token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ... (lines 323-325 unchanged, not shown)
        padding='max_length',
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

tokenized_train = processed_train.map(tokenize_function, batched=True)
tokenized_validation = processed_validation.map(tokenize_function, batched=True)

# Add 'labels' field (CPU)
def add_labels(examples):
    examples['labels'] = examples['input_ids'].copy()
    return examples
# ... (line 339 unchanged, not shown)
tokenized_train = tokenized_train.map(add_labels, batched=True)
tokenized_validation = tokenized_validation.map(add_labels, batched=True)

# Remove unnecessary columns (CPU)
tokenized_train = tokenized_train.remove_columns(['passage'])
tokenized_validation = tokenized_validation.remove_columns(['passage'])

# Set format for PyTorch (CPU)
tokenized_train.set_format('torch')
tokenized_validation.set_format('torch')

# Create DatasetDict (CPU)
final_dataset = {
    'train': tokenized_train,
    'validation': tokenized_validation
}

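After tokenization and add_labels, each example carries the fields a causal-LM Trainer expects, with the labels simply copied from the input IDs. A small illustrative sketch (not part of this commit):

sample = final_dataset['train'][0]
print(sample.keys())         # expected: input_ids, attention_mask, labels (as torch tensors)
print(sample['labels'][:5])  # identical to sample['input_ids'][:5], since labels are a copy
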
# Define TrainerCallback to Save Checkpoints Faster
class SaveCheckpointCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.save_steps == 0 and state.global_step != 0:
            checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
            print(f"Saving checkpoint at: {checkpoint_path}")
            trainer = kwargs['trainer']  # Access trainer from kwargs
            trainer.save_model(checkpoint_path)
        return control  # Return current control object

# Load pretrained model with LoRA
pretrained = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    # ... (lines 371-373 unchanged, not shown)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # Causal LM
    pad_to_multiple_of=8
)

def get_step_done() -> int:
    """
    Get the number of training steps completed from the latest checkpoint.

    Returns:
        int: Number of steps completed. Returns 0 if no checkpoint is found.
    """
    checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint-')]
    if not checkpoints:
        return 0
    try:
        # Find the latest checkpoint based on step number
        latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('-')[1]))
        step_done = int(latest_checkpoint.split('-')[1])
        return step_done
    except (IndexError, ValueError) as e:
        print(f"Error parsing checkpoint name: {e}")
        return 0

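get_step_done only inspects directory names, so it can be sanity-checked without running any training. An illustrative sketch (not part of this commit):

os.makedirs(os.path.join(CHECKPOINT_DIR, "checkpoint-50"), exist_ok=True)
os.makedirs(os.path.join(CHECKPOINT_DIR, "checkpoint-100"), exist_ok=True)
print(get_step_done())  # expected: 100 (numeric max via the int() key, not lexicographic order)
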
# Load and Configure LoRA (GPU)
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    # ... (lines 404-411 unchanged, not shown)

@spaces.GPU(duration=30, queue=False)
def run_training() -> str:
    """
    Train the model using GPU with time constraints.

    Returns:
        str: Training result message.
    """

    # TrainingArguments Configuration (GPU)
    training_args = TrainingArguments(
        output_dir=CHECKPOINT_DIR,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        max_steps=300,  # Total training steps
        learning_rate=3e-4,
        weight_decay=0.01,
        logging_steps=1,  # Log every step
        eval_strategy="steps",  # Evaluate every few steps
        eval_steps=5,  # Evaluate every 5 steps
        save_strategy="steps",  # Save checkpoint every few steps
        save_steps=5,  # Save every 5 steps
        save_total_limit=5,  # Limit number of saved checkpoints
        fp16=True,
        report_to="none",
        load_best_model_at_end=True,
    )

    # Initialize Trainer (GPU)
    trainer = Trainer(
        model=pretrained_model,
        args=training_args,
        # ... (line 446 unchanged, not shown)
        eval_dataset=final_dataset['validation'],
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[SaveCheckpointCallback()],  # Add callback
    )

    # Check for existing checkpoint
    steps_done = get_step_done()
    if steps_done > 0:
        # Determine the latest checkpoint based on step number
        latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
        if os.path.exists(latest_checkpoint):
            print(f"Resuming training from checkpoint: {latest_checkpoint}")
            trainer.train(resume_from_checkpoint=latest_checkpoint)
        else:
            print(f"Checkpoint {latest_checkpoint} does not exist. Starting training from scratch.")
            trainer.train()
    else:
        trainer.train()

    # Save checkpoint after training
    trainer.save_model(CHECKPOINT_DIR)
    return "Training completed or resumed from checkpoint."

# Automatic Function to Call Training Repeatedly
@spaces.GPU(duration=30, queue=False)
def continuous_training(total_steps=300, steps_per_call=50):
    """
    Automatically call `run_training` to complete the training process.

    Args:
        total_steps (int): Desired total training steps.
        steps_per_call (int): Training steps per function call.
    """
    steps_done = get_step_done()
    while steps_done < total_steps:
        remaining_steps = total_steps - steps_done
        current_steps = min(steps_per_call, remaining_steps)
        print(f"Starting training for {current_steps} steps.")

        # Update TrainingArguments for current_steps
        training_args = TrainingArguments(
            output_dir=CHECKPOINT_DIR,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            gradient_accumulation_steps=8,
            num_train_epochs=1,  # Train for one epoch
            max_steps=current_steps,
            learning_rate=3e-4,
            weight_decay=0.01,
            # ... (lines 497-504 unchanged, not shown)
            load_best_model_at_end=True,
        )

        # Initialize Trainer with updated TrainingArguments
        trainer = Trainer(
            model=pretrained_model,
            args=training_args,
            # ... (lines 512-515 unchanged, not shown)
            callbacks=[SaveCheckpointCallback()],
        )

        # Resume training from latest checkpoint
        if steps_done > 0:
            latest_checkpoint = os.path.join(CHECKPOINT_DIR, f"checkpoint-{steps_done}")
            if os.path.exists(latest_checkpoint):
                print(f"Resuming training from checkpoint: {latest_checkpoint}")
                trainer.train(resume_from_checkpoint=latest_checkpoint)
            else:
                print(f"Checkpoint {latest_checkpoint} does not exist. Starting training from scratch.")
                trainer.train()
        else:
            trainer.train()

        steps_done = get_step_done()
        print(f"Trained {steps_done} / {total_steps} steps.")

        # Check if desired steps are achieved
        if steps_done >= total_steps:
            print("Completed the entire training process.")
            break

        # Wait before the next training call
        time.sleep(2)  # Adjust wait time as needed

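continuous_training does not appear to be called from the visible code (the "train" keyword path invokes run_training), so it would typically be run by hand. A minimal usage sketch (not part of this commit), using the function's own defaults:

# Run training in 50-step chunks until 300 total steps are done,
# resuming from the newest checkpoint in ./checkpoints each time.
continuous_training(total_steps=300, steps_per_call=50)
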
# ---------------------------- Gradio Interface ---------------------------- #

@spaces.GPU(duration=30, queue=False)
def generate(
    # ... (lines 546-551 unchanged, not shown)
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """
    Main function to handle user input and generate responses.
    """
    # Notify about query analysis
    yield "🔍 Analyzing your query..."

    # Determine which function to call based on user's message
    function_call = process_query(message)

    # Notify about the selected function
    if function_call["name"] == "web_search":
        yield "🛠️ Selected function: Web Search."
    elif function_call["name"] == "summarize_query":
        yield "🛠️ Selected function: Text Summarization."
    elif function_call["name"] == "sentiment_analysis":
        yield "🛠️ Selected function: Sentiment Analysis."
    elif function_call["name"] in ["general_query", "hard_query"]:
        yield "🛠️ Selected function: Answering Questions."
    elif function_call["name"] == "train_model":
        yield "🛠️ Selected function: Model Training."
    else:
        yield "⚠️ Could not determine an appropriate function."

    # Execute the function call and generate responses
    response_iterator = handle_functions(
        function_call=function_call,
        prompt=message,
        # ... (lines 581-588 unchanged, not shown)
    for response in response_iterator:
        yield response

# Define examples to guide users
EXAMPLES = [
    ["Hello! How are you?"],
    ["Can you briefly explain the Python programming language?"],
    ["Explain the plot of Cinderella in one sentence."],
    ["How many hours does a man need to eat a helicopter?"],
    ["Write a 100-word article on 'Benefits of Open Source in AI Research'"],
    ["Search and provide me with the latest news on renewable energy."],
    ["Find information about the Great Barrier Reef coral reefs."],
    ["Summarize information about artificial intelligence."],
    ["Analyze the sentiment of the following text: I am very happy to meet you today!"],
    ["Train the model!"],
]

# Configure Gradio chat interface with enhanced UI
chat_interface = gr.ChatInterface(
    fn=generate,  # Function called on user interaction
    additional_inputs=[
        gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (Nucleus Sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            # ... (lines 629-635 unchanged, not shown)
            value=50,
        ),
        gr.Slider(
            label="Repetition Penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,  # No stop button
    examples=EXAMPLES,  # Display examples to users
    cache_examples=False,  # Do not cache examples
    title="🤖 OpenGPT-4o Chatbot",
    description="A powerful AI assistant using the local Llama-3.2 model with web search, text summarization, and sentiment analysis functionalities.",
    theme="default",  # Customize theme as needed
)

# Create the main Gradio interface with custom CSS
with gr.Blocks(css="""
.gradio-container {
    background-color: #f0f2f5; /* Light background color */
}
.gradio-container h1 {
    color: #4a90e2; /* Blue color for title */
}
.gradio-container .gr-button {
    background-color: #4a90e2; /* Blue color for buttons */
    color: white; /* White text on buttons */
}
.gradio-container .gr-slider__label {
    color: #333333; /* Dark text for slider labels */
}
.gradio-container .gr-chatbot {
    border: 2px solid #4a90e2; /* Blue border for chatbot */
    border-radius: 10px; /* Rounded corners for chatbot */
    padding: 10px; /* Inner padding for chatbot */
    background-color: #ffffff; /* White background for chatbot */
}
""", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)  # Display description
    chat_interface.render()  # Render chat interface

if __name__ == "__main__":
    demo.queue(max_size=30).launch()  # Launch Gradio app with a queue size of 30