Spaces:
Running
Running
import torch | |
import gradio as gr | |
import multiprocessing | |
import os | |
import time | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer | |
from peft import get_peft_model, LoraConfig, TaskType | |
from datasets import load_dataset | |
# Module-level state shared by the functions in this file.
device: str = "cpu"  # the fine-tuned model is moved to this device
log_file: str = "training_status.log"  # written by log_status, read by read_status
training_process: "multiprocessing.Process | None" = None  # set by start_training
def log_status(message, path=None):
    """Replace the contents of the status file with *message*.

    The training process writes this file and the UI process reads it. So the
    text is first written to a sibling temporary file and then swapped in with
    ``os.replace``. A concurrent reader sees either the old or the new message,
    never a truncated file.

    Args:
        message: Status text to store (may contain non-ASCII characters).
        path: File to write; defaults to the module-level ``log_file``.
    """
    target = log_file if path is None else path
    tmp_path = f"{target}.tmp"
    # Explicit UTF-8: the messages contain Persian text and emoji, which a
    # non-UTF-8 platform default encoding cannot encode.
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(message)
    os.replace(tmp_path, target)
def read_status(path=None):
    """Return the latest status message, or a waiting placeholder.

    Args:
        path: File to read; defaults to the module-level ``log_file``.

    Returns:
        The file's text decoded as UTF-8, or the "waiting for training to
        start" placeholder when the file does not exist yet.
    """
    target = log_file if path is None else path
    # EAFP instead of exists()+open(): the file may appear or disappear
    # between the two calls.
    try:
        # Explicit UTF-8: the messages contain Persian text and emoji, and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(target, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "⏳ در انتظار شروع ترینینگ..."
def _load_lora_model(model_url):
    """Load the tokenizer and causal LM for *model_url* and wrap the model with LoRA.

    Returns:
        ``(tokenizer, model)``; ``model`` is the PEFT-wrapped model moved to ``device``.
    """
    # SECURITY NOTE(review): trust_remote_code=True executes Python code shipped
    # with the model repository, and model_url comes straight from the web UI.
    # Anyone who can reach the app can therefore run arbitrary code on this host.
    tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # padding="max_length" (see _load_train_dataset) needs a pad token,
        # which many causal-LM tokenizers lack; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_url, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
    )
    # Gradient checkpointing (enabled in TrainingArguments) on a frozen base
    # model needs the embedding outputs to require grad; otherwise backward
    # fails because no input to the checkpointed segments requires grad.
    model.enable_input_require_grads()
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        # NOTE(review): projection module names are architecture-specific;
        # confirm they exist in the chosen base model.
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.to(device)
    return tokenizer, model


def _load_train_dataset(dataset_url, tokenizer, max_length=256):
    """Load the ``train`` split of *dataset_url* and tokenize its ``text`` column.

    Raises:
        ValueError: if the train split has no ``text`` column.
    """
    raw = load_dataset(dataset_url, split="train")
    if "text" not in raw.column_names:
        # Report the available columns instead of a bare KeyError('text').
        raise ValueError(
            f"Dataset '{dataset_url}' has no 'text' column in its train split; "
            f"available columns: {raw.column_names}"
        )

    def tokenize_function(examples):
        return tokenizer(
            examples["text"], truncation=True, padding="max_length", max_length=max_length
        )

    # Drop the raw columns so only tokenizer outputs reach the Trainer.
    return raw.map(tokenize_function, batched=True, remove_columns=raw.column_names)


def train_model(dataset_url, model_url, epochs):
    """Fine-tune *model_url* with LoRA on *dataset_url* on the CPU.

    This is the target of the background process started by ``start_training``.
    Progress is written with ``log_status``. Any exception is caught, its
    traceback printed to stderr, and its message written to the status file
    (prefixed with "❌") instead of propagating.

    Args:
        dataset_url: Hugging Face dataset id; its ``train`` split must have a
            ``text`` column.
        model_url: Hugging Face id of the base causal LM.
        epochs: Number of training epochs; converted with ``int``.
    """
    import traceback

    try:
        # Extra names from the transformers package the file already requires.
        from transformers import DataCollatorForLanguageModeling, TrainerCallback

        num_epochs = int(epochs)
        log_status("🚀 در حال بارگیری مدل...")
        tokenizer, model = _load_lora_model(model_url)
        train_dataset = _load_train_dataset(dataset_url, tokenizer)

        class _EpochStatusCallback(TrainerCallback):
            """Write an "Epoch i/N" status line at the start of each epoch."""

            def __init__(self):
                super().__init__()
                self.current_epoch = 0

            def on_epoch_begin(self, args, state, control, **kwargs):
                self.current_epoch += 1
                log_status(f"🔄 در حال اجرا: Epoch {self.current_epoch}/{num_epochs}...")

        # No evaluation strategy: only a train split is used, and an epoch-wise
        # evaluation without an eval_dataset makes training fail.
        training_args = TrainingArguments(
            output_dir="./deepseek_lora_cpu",
            learning_rate=5e-4,
            per_device_train_batch_size=1,
            num_train_epochs=num_epochs,
            save_strategy="epoch",
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,
            gradient_checkpointing=True,
            optim="adamw_torch",
            report_to="none",
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            # mlm=False builds causal-LM labels from input_ids so the model
            # returns a loss; the tokenized dataset has no labels of its own.
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
            callbacks=[_EpochStatusCallback()],
        )
        log_status("🚀 ترینینگ شروع شد...")
        # One train() call runs all num_epochs epochs; save_strategy="epoch"
        # already checkpoints each epoch into output_dir. Do not loop over
        # epochs with train(resume_from_checkpoint=True): that raises when
        # output_dir has no checkpoint yet.
        trainer.train()
        trainer.save_model(f"./deepseek_lora_finetuned_epoch_{num_epochs}")
        log_status("✅ ترینینگ کامل شد!")
    except Exception as e:
        # Top-level boundary of the worker process: keep the traceback for
        # debugging and surface the message in the UI.
        traceback.print_exc()
        log_status(f"❌ خطا: {str(e)}")
def start_training(dataset_url, model_url, epochs):
    """Launch ``train_model`` in a background process unless one is already running.

    Args:
        dataset_url: Hugging Face dataset id from the UI (surrounding whitespace
            is stripped).
        model_url: Hugging Face model id from the UI (surrounding whitespace is
            stripped).
        epochs: Epoch count from the UI slider.

    Returns:
        A short status message for the UI.
    """
    global training_process
    dataset_url = (dataset_url or "").strip()
    model_url = (model_url or "").strip()
    # Don't spawn a process that can only fail while loading an empty id.
    if not dataset_url or not model_url:
        return "⚠ لطفاً آدرس دیتاست و مدل را وارد کنید!"
    if training_process is not None and training_process.is_alive():
        return "⚠ ترینینگ در حال اجرا است!"
    training_process = multiprocessing.Process(
        target=train_model, args=(dataset_url, model_url, epochs)
    )
    training_process.start()
    return "🚀 ترینینگ شروع شد!"
def update_status():
    """Refresh-button callback: fetch the current training status text."""
    current = read_status()
    return current
with gr.Blocks() as app:
    gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - نمایش وضعیت لحظهای")
    dataset_url = gr.Textbox(label="Dataset URL (Hugging Face)", placeholder="مثال: samsum")
    model_url = gr.Textbox(label="Model URL (Hugging Face)", placeholder="مثال: deepseek-ai/deepseek-r1")
    epochs = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="تعداد Epochs")
    train_button = gr.Button("شروع ترینینگ")
    output_text = gr.Textbox(label="وضعیت ترینینگ")
    train_button.click(start_training, inputs=[dataset_url, model_url, epochs], outputs=output_text)

    # Training-status display, refreshed manually via the button below.
    status_box = gr.Textbox(label="مرحله فعلی ترینینگ", interactive=False)
    refresh_button = gr.Button("🔄 بهروزرسانی وضعیت")
    refresh_button.click(update_status, inputs=[], outputs=status_box)

app.queue()

# Guard the launch: with the "spawn"/"forkserver" multiprocessing start
# methods, the training child re-imports this module (as __mp_main__), and an
# unguarded launch() would try to start a second server there.
if __name__ == "__main__":
    # NOTE(review): share=True publishes a public link, and the UI lets any
    # visitor load an arbitrary model with trust_remote_code=True — consider
    # disabling sharing or adding authentication.
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)