File size: 3,814 Bytes
2b35f7d
0b40748
 
 
 
 
3f9d883
0b40748
 
3f9d883
0b40748
 
3f9d883
0b40748
92992ea
0b40748
3f9d883
0b40748
 
 
 
 
 
 
 
 
 
 
3f9d883
0b40748
 
3f9d883
0b40748
 
 
 
 
 
3f9d883
0b40748
 
 
 
3f9d883
0b40748
 
 
 
 
 
3f9d883
 
0b40748
 
 
 
 
 
 
 
 
 
3f9d883
0b40748
3f9d883
0b40748
 
3f9d883
 
0b40748
3f9d883
0b40748
3f9d883
0b40748
3f9d883
 
 
 
 
0b40748
3f9d883
 
0b40748
3f9d883
 
 
 
0b40748
3f9d883
2b35f7d
3f9d883
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset

# ✅ بررسی سخت‌افزار (CPU/GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# ✅ تابع اجرای ترینینگ (قفل شده تا پایان)
def train_model(dataset_url, model_url, epochs):
    try:
        # 🚀 بارگیری مدل و توکنایزر
        tokenizer = AutoTokenizer.from_pretrained(model_url)
        model = AutoModelForCausalLM.from_pretrained(model_url).to(device)

        # ✅ تنظیم LoRA برای کاهش مصرف حافظه
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"]
        )

        model = get_peft_model(model, lora_config)
        model.to(device)

        # ✅ بارگیری دیتاست
        dataset = load_dataset(dataset_url)

        # ✅ توکنایز کردن داده‌ها
        def tokenize_function(examples):
            return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=256)

        tokenized_datasets = dataset.map(tokenize_function, batched=True)
        train_dataset = tokenized_datasets["train"]

        # ✅ تنظیمات ترینینگ
        training_args = TrainingArguments(
            output_dir="./deepseek_lora_cpu",
            evaluation_strategy="epoch",
            learning_rate=5e-4,
            per_device_train_batch_size=1,  # کاهش مصرف RAM
            per_device_eval_batch_size=1,
            num_train_epochs=int(epochs),
            save_strategy="epoch",
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,  # عدم استفاده از FP16 روی CPU
            gradient_checkpointing=True,  # ذخیره حافظه
            optim="adamw_torch",
            report_to="none"
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset
        )

        # 🚀 شروع ترینینگ (قفل شده تا پایان)
        trainer.train()
        trainer.save_model("./deepseek_lora_finetuned")  # ذخیره نهایی مدل
        tokenizer.save_pretrained("./deepseek_lora_finetuned")

        return "✅ ترینینگ کامل شد! مدل ذخیره شد."

    except Exception as e:
        return f"❌ خطا: {str(e)}"

# ✅ Gradio UI با دکمه‌ی غیرفعال‌شونده
with gr.Blocks() as app:
    gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - (بدون توقف تا پایان)")

    dataset_url = gr.Textbox(label="Dataset URL (Hugging Face)", placeholder="مثال: samsum")
    model_url = gr.Textbox(label="Model URL (Hugging Face)", placeholder="مثال: deepseek-ai/deepseek-r1")
    epochs = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="تعداد Epochs")

    train_button = gr.Button("شروع ترینینگ", interactive=True)
    output_text = gr.Textbox(label="وضعیت ترینینگ")

    # 🚀 بعد از کلیک دکمه را غیرفعال کنیم تا کار متوقف نشود
    def disable_button(*args):
        train_button.interactive = False  # غیرفعال کردن دکمه
        return train_model(*args)

    train_button.click(disable_button, inputs=[dataset_url, model_url, epochs], outputs=output_text)

# ✅ اجرای Gradio در حالت قفل شده
app.queue()  # این خط تضمین می‌کند که پردازش متوقف نشود
app.launch(server_name="0.0.0.0", server_port=7860, share=True, blocking=True)