hackergeek commited on
Commit
e784c2f
·
verified ·
1 Parent(s): 99cd08e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -15
app.py CHANGED
@@ -4,15 +4,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments,
4
  from peft import get_peft_model, LoraConfig, TaskType
5
  from datasets import load_dataset
6
 
7
- # ✅ بررسی سخت‌افزار (CPU/GPU)
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
- # ✅ تابع اجرای ترینینگ (قفل شده تا پایان)
11
  def train_model(dataset_url, model_url, epochs):
12
  try:
13
- # 🚀 بارگیری مدل و توکنایزر
14
- tokenizer = AutoTokenizer.from_pretrained(model_url)
15
- model = AutoModelForCausalLM.from_pretrained(model_url).to(device)
16
 
17
  # ✅ تنظیم LoRA برای کاهش مصرف حافظه
18
  lora_config = LoraConfig(
@@ -41,15 +39,15 @@ def train_model(dataset_url, model_url, epochs):
41
  output_dir="./deepseek_lora_cpu",
42
  evaluation_strategy="epoch",
43
  learning_rate=5e-4,
44
- per_device_train_batch_size=1, # کاهش مصرف RAM
45
  per_device_eval_batch_size=1,
46
  num_train_epochs=int(epochs),
47
  save_strategy="epoch",
48
  save_total_limit=2,
49
  logging_dir="./logs",
50
  logging_steps=10,
51
- fp16=False, # عدم استفاده از FP16 روی CPU
52
- gradient_checkpointing=True, # ذخیره حافظه
53
  optim="adamw_torch",
54
  report_to="none"
55
  )
@@ -62,7 +60,7 @@ def train_model(dataset_url, model_url, epochs):
62
 
63
  # 🚀 شروع ترینینگ (قفل شده تا پایان)
64
  trainer.train()
65
- trainer.save_model("./deepseek_lora_finetuned") # ذخیره نهایی مدل
66
  tokenizer.save_pretrained("./deepseek_lora_finetuned")
67
 
68
  return "✅ ترینینگ کامل شد! مدل ذخیره شد."
@@ -70,7 +68,7 @@ def train_model(dataset_url, model_url, epochs):
70
  except Exception as e:
71
  return f"❌ خطا: {str(e)}"
72
 
73
- # ✅ Gradio UI با دکمه‌ی غیرفعال‌شونده
74
  with gr.Blocks() as app:
75
  gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - (بدون توقف تا پایان)")
76
 
@@ -81,13 +79,11 @@ with gr.Blocks() as app:
81
  train_button = gr.Button("شروع ترینینگ", interactive=True)
82
  output_text = gr.Textbox(label="وضعیت ترینینگ")
83
 
84
- # 🚀 بعد از کلیک دکمه را غیرفعال کنیم تا کار متوقف نشود
85
  def disable_button(*args):
86
- train_button.interactive = False # غیرفعال کردن دکمه
87
  return train_model(*args)
88
 
89
  train_button.click(disable_button, inputs=[dataset_url, model_url, epochs], outputs=output_text)
90
 
91
- # ✅ اجرای Gradio در حالت قفل شده
92
- app.queue() # این خط تضمین می‌کند که پردازش متوقف نشود
93
- app.launch(server_name="0.0.0.0", server_port=7860, share=True) # ❌ `blocking=True` حذف شد
 
4
  from peft import get_peft_model, LoraConfig, TaskType
5
  from datasets import load_dataset
6
 
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
 
9
  def train_model(dataset_url, model_url, epochs):
10
  try:
11
+ # 🚀 بارگیری مدل و توکنایزر با `trust_remote_code=True`
12
+ tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
13
+ model = AutoModelForCausalLM.from_pretrained(model_url, trust_remote_code=True).to(device)
14
 
15
  # ✅ تنظیم LoRA برای کاهش مصرف حافظه
16
  lora_config = LoraConfig(
 
39
  output_dir="./deepseek_lora_cpu",
40
  evaluation_strategy="epoch",
41
  learning_rate=5e-4,
42
+ per_device_train_batch_size=1,
43
  per_device_eval_batch_size=1,
44
  num_train_epochs=int(epochs),
45
  save_strategy="epoch",
46
  save_total_limit=2,
47
  logging_dir="./logs",
48
  logging_steps=10,
49
+ fp16=False,
50
+ gradient_checkpointing=True,
51
  optim="adamw_torch",
52
  report_to="none"
53
  )
 
60
 
61
  # 🚀 شروع ترینینگ (قفل شده تا پایان)
62
  trainer.train()
63
+ trainer.save_model("./deepseek_lora_finetuned")
64
  tokenizer.save_pretrained("./deepseek_lora_finetuned")
65
 
66
  return "✅ ترینینگ کامل شد! مدل ذخیره شد."
 
68
  except Exception as e:
69
  return f"❌ خطا: {str(e)}"
70
 
71
+ # ✅ رابط کاربری Gradio
72
  with gr.Blocks() as app:
73
  gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - (بدون توقف تا پایان)")
74
 
 
79
  train_button = gr.Button("شروع ترینینگ", interactive=True)
80
  output_text = gr.Textbox(label="وضعیت ترینینگ")
81
 
 
82
  def disable_button(*args):
83
+ train_button.interactive = False
84
  return train_model(*args)
85
 
86
  train_button.click(disable_button, inputs=[dataset_url, model_url, epochs], outputs=output_text)
87
 
88
+ app.queue()
89
+ app.launch(server_name="0.0.0.0", server_port=7860, share=True)