# Tuning / app.py
# Web-page header captured with the file: last commit "Update app.py" by hackergeek
# (e88f543, verified); file size 4.14 kB. Page links: raw, history, blame.
import torch
import gradio as gr
import multiprocessing
import os
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
# Device the model is moved to before training; this app is pinned to the CPU.
device = "cpu"
# Handle of the background training process created by start_training();
# stays None until the first training run is launched.
training_process = None
# File through which training code reports its status (written by log_status,
# read back by read_status).
log_file = "training_status.log"
def log_status(message):
    """Replace the contents of the status file with ``message``.

    The status file is written by the training process and read by the UI
    process, so the update is done atomically: the text is written to a
    temporary sibling file and then moved over ``log_file``. A concurrent
    reader therefore sees either the previous message or the new one, never
    an empty or half-written file.

    Args:
        message: Status text to store. It is encoded as UTF-8 explicitly
            because the messages contain Persian text and emoji, which the
            platform default encoding may not be able to represent.
    """
    # The PID keeps the temporary name unique per writing process.
    tmp_path = f"{log_file}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(message)
        # os.replace is atomic when source and destination share a filesystem.
        os.replace(tmp_path, log_file)
    finally:
        # After a successful replace the temporary file no longer exists;
        # this only cleans up when the write or the replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def read_status():
    """Return the most recent status message stored in the status file.

    Returns:
        The text currently stored in ``log_file``, or a default
        "waiting for training to start" message when the file does not exist.
    """
    # Open directly and handle the missing file (EAFP) instead of checking
    # os.path.exists first: the file could disappear between check and open.
    try:
        # Decode as UTF-8 explicitly; the messages contain Persian text and emoji.
        with open(log_file, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "⏳ در انتظار شروع ترینینگ..."
def train_model(dataset_url, model_url, epochs):
    """Fine-tune a causal language model with LoRA on the CPU, reporting progress.

    Progress and errors are reported through ``log_status``; nothing is
    returned and no exception escapes, because this function is used as the
    target of a background process.

    Args:
        dataset_url: Dataset identifier passed to ``load_dataset``. The
            dataset must have a ``train`` split with a ``text`` column.
        model_url: Model identifier passed to ``from_pretrained``.
        epochs: Number of training epochs; converted with ``int()``.
    """
    try:
        # Imported here so the block only relies on the already-installed
        # transformers package, without touching the file-level imports.
        from transformers import DataCollatorForLanguageModeling, TrainerCallback

        total_epochs = int(epochs)

        log_status("🚀 در حال بارگیری مدل...")
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # inside the repository named by model_url, and model_url comes
        # straight from the web UI. Restrict who can reach this app, or
        # allow-list the accepted model ids, before exposing it publicly.
        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        if tokenizer.pad_token is None:
            # padding="max_length" below needs a pad token; many causal-LM
            # tokenizers define none, so fall back to the EOS token.
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_url, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
        )
        # With gradient checkpointing and a frozen base model, the embedding
        # outputs must require grad or no gradient reaches the LoRA weights.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"],
        )
        model = get_peft_model(model, lora_config)
        model.to(device)

        dataset = load_dataset(dataset_url)
        train_split = dataset["train"]
        if "text" not in train_split.column_names:
            # Fail with an actionable message instead of a bare KeyError('text').
            raise ValueError(
                f"Dataset '{dataset_url}' has no 'text' column in its train split "
                f"(available columns: {train_split.column_names})"
            )

        def tokenize_function(examples):
            return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=256)

        # Only the train split is used, so only that split is tokenized. The
        # raw columns are dropped so the collator receives token fields only.
        train_dataset = train_split.map(
            tokenize_function, batched=True, remove_columns=train_split.column_names
        )

        # Causal-LM training needs `labels`; with mlm=False this collator
        # copies input_ids into labels and masks padding positions.
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        # No evaluation strategy is configured because no evaluation dataset
        # is handed to the Trainer.
        training_args = TrainingArguments(
            output_dir="./deepseek_lora_cpu",
            learning_rate=5e-4,
            per_device_train_batch_size=1,
            num_train_epochs=total_epochs,
            save_strategy="epoch",
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,
            gradient_checkpointing=True,
            optim="adamw_torch",
            report_to="none",
        )

        class _EpochStatusCallback(TrainerCallback):
            """Publish the running epoch and save the model after each epoch."""

            def __init__(self):
                self.epoch_number = 0

            def on_epoch_begin(self, args, state, control, **kwargs):
                self.epoch_number += 1
                log_status(f"🔄 در حال اجرا: Epoch {self.epoch_number}/{total_epochs}...")

            def on_epoch_end(self, args, state, control, **kwargs):
                # `trainer` is bound below, before train() triggers any callback.
                trainer.save_model(f"./deepseek_lora_finetuned_epoch_{self.epoch_number}")

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=data_collator,
            callbacks=[_EpochStatusCallback()],
        )

        log_status("🚀 ترینینگ شروع شد...")
        # A single train() call already runs all `num_train_epochs` epochs.
        # Calling it once per epoch with resume_from_checkpoint=True fails on
        # a fresh output directory (there is no checkpoint to resume from)
        # and would otherwise repeat the whole schedule.
        trainer.train()
        log_status("✅ ترینینگ کامل شد!")
    except Exception as e:
        # Top-level boundary of the worker: surface the failure to the UI.
        log_status(f"❌ خطا: {str(e)}")
def start_training(dataset_url, model_url, epochs):
    """Launch ``train_model`` in a background process unless one is already running.

    Args:
        dataset_url: Dataset identifier forwarded to ``train_model``.
        model_url: Model identifier forwarded to ``train_model``.
        epochs: Number of epochs forwarded to ``train_model``.

    Returns:
        A short message saying whether a new run was started or a run is
        still in progress.
    """
    global training_process

    # Guard clause: refuse to start a second run while one is still alive.
    already_running = training_process is not None and training_process.is_alive()
    if already_running:
        return "⚠ ترینینگ در حال اجرا است!"

    worker_args = (dataset_url, model_url, epochs)
    training_process = multiprocessing.Process(target=train_model, args=worker_args)
    training_process.start()
    return "🚀 ترینینگ شروع شد!"
def update_status():
    """Return the current training status text for display in the UI."""
    current_status = read_status()
    return current_status
# Gradio UI: inputs for the dataset, model and epoch count, a button that
# starts training in the background, and a manually refreshed status box.
with gr.Blocks() as app:
    gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - نمایش وضعیت لحظه‌ای")
    dataset_url = gr.Textbox(label="Dataset URL (Hugging Face)", placeholder="مثال: samsum")
    model_url = gr.Textbox(label="Model URL (Hugging Face)", placeholder="مثال: deepseek-ai/deepseek-r1")
    epochs = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="تعداد Epochs")
    train_button = gr.Button("شروع ترینینگ")
    output_text = gr.Textbox(label="وضعیت ترینینگ")
    train_button.click(start_training, inputs=[dataset_url, model_url, epochs], outputs=output_text)
    # Display of the current training status, refreshed on demand.
    status_box = gr.Textbox(label="مرحله فعلی ترینینگ", interactive=False)
    refresh_button = gr.Button("🔄 به‌روزرسانی وضعیت")
    refresh_button.click(update_status, inputs=[], outputs=status_box)

# Start the server only when this file is executed as a script. With the
# "spawn" start method, multiprocessing re-imports the main module inside the
# training child process; without this guard the child would try to launch a
# second server on the same port instead of running the training function.
if __name__ == "__main__":
    app.queue()
    # NOTE(review): share=True requests a public tunnel URL in addition to
    # binding on all interfaces; confirm that this exposure is intended.
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)