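"""
Gradio app for LoRA fine-tuning a causal language model (e.g. DeepSeek R1) on CPU.

Training runs in a background process and reports its current stage through a
small status file, which the UI reads back when the refresh button is pressed.
"""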
import torch
import gradio as gr
import multiprocessing
import os
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, DataCollatorForLanguageModeling
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset

device = "cpu"
training_process = None
log_file = "training_status.log"
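
# The trainer runs in a separate process, so progress is shared through this
# status file: log_status() overwrites it and the UI re-reads it on refresh.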

def log_status(message):
    with open(log_file, "w") as f:
        f.write(message)

def read_status():
    if os.path.exists(log_file):
        with open(log_file, "r") as f:
            return f.read()
    return "⏳ Waiting for training to start..."

class StatusCallback(TrainerCallback):
    """Reports each finished epoch to the status file so the UI can show live progress."""
    def on_epoch_end(self, args, state, control, **kwargs):
        log_status(f"🔄 Finished epoch {int(state.epoch)} of {int(args.num_train_epochs)}")

def train_model(dataset_url, model_url, epochs):
    try:
        log_status("🚀 در حال بارگیری مدل...")
        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_url, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
        )

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["q_proj", "v_proj"]
        )

        # Enable input grads so gradient checkpointing works with LoRA-frozen base weights.
        model.enable_input_require_grads()
        model = get_peft_model(model, lora_config)
        model.to(device)

        dataset = load_dataset(dataset_url)
        def tokenize_function(examples):
            return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=256)

        tokenized_datasets = dataset.map(tokenize_function, batched=True)
        train_dataset = tokenized_datasets["train"]
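        # Note: load_dataset(dataset_url) must yield a dataset with a "text" column
        # for the tokenizer mapping above to work.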

        # No eval split is passed to the Trainer, so evaluation stays at its
        # default of "no" and no eval batch size is needed.
        training_args = TrainingArguments(
            output_dir="./deepseek_lora_cpu",
            learning_rate=5e-4,
            per_device_train_batch_size=1,
            num_train_epochs=int(epochs),
            save_strategy="epoch",
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,
            gradient_checkpointing=True,
            optim="adamw_torch",
            report_to="none"
        )

        # mlm=False makes the collator copy input_ids into labels, which the model
        # needs in order to compute the causal-LM loss.
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=data_collator,
            callbacks=[StatusCallback()],
        )

        log_status("🚀 Training started...")

        # One train() call runs all num_train_epochs epochs; save_strategy="epoch"
        # checkpoints after each epoch and StatusCallback reports per-epoch progress.
        trainer.train()
        trainer.save_model("./deepseek_lora_finetuned")

        log_status("✅ Training complete!")

    except Exception as e:
        log_status(f"❌ خطا: {str(e)}")

def start_training(dataset_url, model_url, epochs):
    """Start training in a background process unless one is already running."""
    global training_process
    if training_process is None or not training_process.is_alive():
        training_process = multiprocessing.Process(target=train_model, args=(dataset_url, model_url, epochs))
        training_process.start()
        return "🚀 Training started!"
    else:
        return "⚠ Training is already in progress!"

def update_status():
    return read_status()

with gr.Blocks() as app:
    gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - نمایش وضعیت لحظه‌ای")

    dataset_url = gr.Textbox(label="Dataset URL (Hugging Face)", placeholder="مثال: samsum")
    model_url = gr.Textbox(label="Model URL (Hugging Face)", placeholder="مثال: deepseek-ai/deepseek-r1")
    epochs = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="تعداد Epochs")

    train_button = gr.Button("شروع ترینینگ")
    output_text = gr.Textbox(label="وضعیت ترینینگ")

    train_button.click(start_training, inputs=[dataset_url, model_url, epochs], outputs=output_text)

    # ✅ Live training status (polled via the refresh button)
    status_box = gr.Textbox(label="Current Training Stage", interactive=False)
    refresh_button = gr.Button("🔄 Refresh Status")

    refresh_button.click(update_status, inputs=[], outputs=status_box)

if __name__ == "__main__":
    # Guard so multiprocessing's spawn start method can re-import this module safely.
    app.queue()
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)