"""Gradio front-end that fine-tunes a causal LM with LoRA on CPU.

Training runs in a separate process; progress is communicated to the UI
through a small status file that the UI reads on demand.
"""

import inspect
import multiprocessing
import os

import gradio as gr
import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

device = "cpu"
training_process = None
log_file = "training_status.log"


# Logging function
def log_status(message):
    """Overwrite the status file with ``message``.

    The file is written as UTF-8 explicitly because the status messages
    contain non-ASCII text and must not depend on the platform's default
    encoding.
    """
    with open(log_file, "w", encoding="utf-8") as f:
        f.write(message)


# Read training status
def read_status():
    """Return the current contents of the status file.

    Returns a "waiting for training to start" placeholder when the status
    file does not exist yet.
    """
    try:
        with open(log_file, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "⏳ در انتظار شروع ترینینگ..."


# Function to find the text column dynamically
def find_text_column(dataset):
    """Return the name of the first string-valued column of the train split.

    Only the first row of ``dataset["train"]`` is inspected. Returns ``None``
    when there is no ``"train"`` split, the split is empty, or no column of
    the first row holds a ``str``.
    """
    if "train" not in dataset or len(dataset["train"]) == 0:
        return None

    sample = dataset["train"][0]
    for column, value in sample.items():
        if isinstance(value, str):
            return column
    return None


class _EpochStatusCallback(TrainerCallback):
    """Report per-epoch progress and snapshot the model after each epoch."""

    def __init__(self, total_epochs):
        self.total_epochs = total_epochs
        self.current_epoch = 0

    def on_epoch_begin(self, args, state, control, **kwargs):
        # Keep our own 1-based counter for the status message.
        self.current_epoch += 1
        log_status(
            f"🔄 در حال اجرا: Epoch {self.current_epoch}/{self.total_epochs}..."
        )

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        # The Trainer passes the model being trained to callback events.
        if model is not None:
            model.save_pretrained(
                f"./deepseek_lora_finetuned_epoch_{self.current_epoch}"
            )


def _eval_strategy_kwarg():
    """Return the evaluation-strategy keyword this transformers version accepts."""
    params = inspect.signature(TrainingArguments.__init__).parameters
    return "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"


# Model training function
def train_model(dataset_url, model_url, epochs):
    """Fine-tune ``model_url`` with LoRA on ``dataset_url`` for ``epochs`` epochs.

    Intended to run in a child process. It returns nothing; progress, the
    final result, and any error are reported through ``log_status``.

    Args:
        dataset_url: Dataset identifier passed to ``load_dataset``.
        model_url: Model identifier passed to ``from_pretrained``.
        epochs: Number of training epochs; converted with ``int()``.
    """
    try:
        total_epochs = int(epochs)

        log_status("🚀 در حال بارگیری مدل...")
        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        # Padding to max_length requires a pad token; many causal LM
        # tokenizers ship without one, so fall back to the EOS token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_url,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            device_map="cpu",
        )
        # With gradient checkpointing and a frozen base model, the embedding
        # outputs must require grad or backward has nothing to propagate.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()

        # NOTE(review): target_modules assumes the base model names its
        # attention projections "q_proj"/"v_proj" - confirm for other models.
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"],
        )
        model = get_peft_model(model, lora_config)
        model.to(device)

        dataset = load_dataset(dataset_url)

        # Automatically detect the correct text column
        text_column = find_text_column(dataset)
        if not text_column:
            log_status("❌ خطا: ستون متنی در دیتاست یافت نشد!")
            return

        def tokenize_function(examples):
            encoded = tokenizer(
                examples[text_column],
                truncation=True,
                padding="max_length",
                max_length=256,
                return_attention_mask=True,
            )
            # Causal LM training needs labels; mask padded positions with
            # -100 so they are ignored by the loss.
            encoded["labels"] = [
                [token if keep else -100 for token, keep in zip(ids, mask)]
                for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
            ]
            return encoded

        tokenized_datasets = dataset.map(tokenize_function, batched=True)
        train_dataset = tokenized_datasets["train"]

        # Automatically check for validation dataset
        eval_dataset = (
            tokenized_datasets["validation"]
            if "validation" in tokenized_datasets
            else None
        )

        training_kwargs = {
            "output_dir": "./deepseek_lora_cpu",
            "learning_rate": 5e-4,
            "per_device_train_batch_size": 1,
            "per_device_eval_batch_size": 1,
            "num_train_epochs": total_epochs,
            "save_strategy": "epoch",
            "save_total_limit": 2,
            "logging_dir": "./logs",
            "logging_steps": 10,
            "fp16": False,
            "gradient_checkpointing": True,
            "optim": "adamw_torch",
            "report_to": "none",
        }
        # Enable evaluation only if validation data exists; the keyword name
        # differs between transformers versions.
        training_kwargs[_eval_strategy_kwarg()] = (
            "epoch" if eval_dataset is not None else "no"
        )
        training_args = TrainingArguments(**training_kwargs)

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            callbacks=[_EpochStatusCallback(total_epochs)],
        )

        log_status("🚀 ترینینگ شروع شد...")
        # A single train() call already runs num_train_epochs epochs; the
        # callback handles per-epoch status updates and snapshots.
        trainer.train()

        log_status("✅ ترینینگ کامل شد!")

    except Exception as e:
        # Top-level boundary of the worker process: surface the error to the
        # UI through the status file instead of dying silently.
        log_status(f"❌ خطا: {str(e)}")


# Start training in a separate process
def start_training(dataset_url, model_url, epochs):
    """Launch ``train_model`` in a child process unless one is still running.

    Returns a status string for the UI: a "started" message when a new
    process was launched, or an "already running" warning otherwise.
    """
    global training_process

    if training_process is not None and training_process.is_alive():
        return "⚠ ترینینگ در حال اجرا است!"

    # Reset the status file so a message from a previous run is not shown
    # while the new worker is still starting up.
    log_status("🚀 ترینینگ شروع شد!")
    training_process = multiprocessing.Process(
        target=train_model, args=(dataset_url, model_url, epochs)
    )
    training_process.start()
    return "🚀 ترینینگ شروع شد!"


# Function to update the status
def update_status():
    """Return the latest training status for display in the UI."""
    return read_status()


# Gradio UI
with gr.Blocks() as app:
    gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - نمایش وضعیت لحظه‌ای")

    with gr.Row():
        dataset_input = gr.Textbox(label="📂 لینک دیتاست (Hugging Face)")
        model_input = gr.Textbox(label="🤖 مدل پایه (Hugging Face)")
        epochs_input = gr.Number(label="🔄 تعداد Epochs", value=3)

    start_button = gr.Button("🚀 شروع ترینینگ")
    status_output = gr.Textbox(label="📢 وضعیت ترینینگ", interactive=False)

    start_button.click(
        start_training,
        inputs=[dataset_input, model_input, epochs_input],
        outputs=status_output,
    )

    status_button = gr.Button("🔄 بروزرسانی وضعیت")
    status_button.click(update_status, outputs=status_output)

# Launch only when run as a script: with the "spawn" start method the worker
# process re-imports this module and must not start a second server.
if __name__ == "__main__":
    app.launch()