hoduyquocbao commited on
Commit
1655dfc
·
1 Parent(s): a75374c

update spaces.GPU

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -3,6 +3,7 @@ from threading import Thread
3
  from typing import Iterator, List, Tuple, Dict, Any
4
 
5
  import gradio as gr
 
6
  import torch
7
  from transformers import (
8
  TrainingArguments,
@@ -406,6 +407,7 @@ lora_config = LoraConfig(
406
  pretrained_model = get_peft_model(pretrained, lora_config)
407
  print(pretrained_model)
408
 
 
409
  def run_training() -> str:
410
  """
411
  Hàm huấn luyện mô hình sử dụng GPU với thời gian hạn chế.
@@ -465,6 +467,7 @@ def run_training() -> str:
465
  return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
466
 
467
  # Hàm Tự Động Hóa Việc Gọi Lặp Lại Hàm Huấn Luyện
 
468
  def continuous_training(total_steps=300, steps_per_call=50):
469
  """
470
  Hàm tự động gọi lại `run_training` để hoàn thành quá trình huấn luyện.
@@ -536,6 +539,7 @@ def continuous_training(total_steps=300, steps_per_call=50):
536
 
537
  # ---------------------------- Giao Diện Gradio ---------------------------- #
538
 
 
539
  def generate(
540
  message: str,
541
  chat_history: List[Tuple[str, str]],
 
3
  from typing import Iterator, List, Tuple, Dict, Any
4
 
5
  import gradio as gr
6
+ import spaces
7
  import torch
8
  from transformers import (
9
  TrainingArguments,
 
407
  pretrained_model = get_peft_model(pretrained, lora_config)
408
  print(pretrained_model)
409
 
410
+ @spaces.GPU(duration=30, queue=False)
411
  def run_training() -> str:
412
  """
413
  Hàm huấn luyện mô hình sử dụng GPU với thời gian hạn chế.
 
467
  return "Huấn luyện hoàn tất hoặc đã tiếp tục từ checkpoint."
468
 
469
  # Hàm Tự Động Hóa Việc Gọi Lặp Lại Hàm Huấn Luyện
470
+ @spaces.GPU(duration=30, queue=False)
471
  def continuous_training(total_steps=300, steps_per_call=50):
472
  """
473
  Hàm tự động gọi lại `run_training` để hoàn thành quá trình huấn luyện.
 
539
 
540
  # ---------------------------- Giao Diện Gradio ---------------------------- #
541
 
542
+ @spaces.GPU(duration=30, queue=False)
543
  def generate(
544
  message: str,
545
  chat_history: List[Tuple[str, str]],