File size: 4,998 Bytes
2b35f7d
0b40748
a51c773
 
0b40748
 
 
 
a51c773
 
 
 
006af89
a51c773
 
 
 
006af89
a51c773
 
 
 
 
0b40748
006af89
 
 
 
 
 
 
 
 
0b40748
 
a51c773
006af89
e784c2f
777e328
a51c773
777e328
0b40748
 
006af89
 
 
 
 
0b40748
 
 
 
 
006af89
 
 
 
 
 
 
0b40748
006af89
0b40748
 
 
 
71e1bba
 
 
0b40748
 
71e1bba
0b40748
a51c773
0b40748
 
 
 
 
 
a51c773
e784c2f
0b40748
 
 
 
 
006af89
 
71e1bba
 
0b40748
 
a51c773
0b40748
a51c773
 
 
 
 
 
3f9d883
0b40748
a51c773
 
006af89
a51c773
 
 
 
 
 
 
 
0b40748
006af89
e88f543
 
 
006af89
0b40748
a51c773
3f9d883
006af89
 
 
 
a51c773
006af89
 
0b40748
006af89
 
 
2b35f7d
006af89
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
import gradio as gr
import multiprocessing
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset

# Device the model is moved to before training; this script runs on CPU only.
device = "cpu"
# Handle of the most recently started training process (None until one is started).
training_process = None
# File through which the training process passes its latest status to the UI.
log_file = "training_status.log"

# Logging function
def log_status(message):
    """Overwrite the status file with ``message``.

    The file always holds only the most recent status; earlier content is
    discarded on every call.

    Args:
        message: Status text to store.
    """
    # The status messages contain Persian text and emoji. Writing them with
    # the platform default encoding (e.g. cp1252 on Windows) raises
    # UnicodeEncodeError, so the encoding is pinned to UTF-8.
    with open(log_file, "w", encoding="utf-8") as f:
        f.write(message)

# Read training status
def read_status():
    """Return the latest training status message.

    Returns:
        The full content of the status file, or a "waiting for training to
        start" placeholder when the file does not exist yet.
    """
    try:
        # UTF-8 is pinned because the status text is Persian plus emoji and
        # the platform default encoding may not decode it. errors="replace"
        # keeps a UI callback from crashing on a file written in some other
        # encoding.
        with open(log_file, "r", encoding="utf-8", errors="replace") as f:
            return f.read()
    except FileNotFoundError:
        # Opening directly (instead of checking os.path.exists first) avoids
        # a race with another process creating or removing the file.
        return "⏳ در انتظار شروع ترینینگ..."

# Function to find the text column dynamically
def find_text_column(dataset):
    """Return the name of the first string-valued column of the train split.

    Only the first row of ``dataset["train"]`` is inspected; columns are
    considered in the row's own key order.

    Args:
        dataset: Mapping with a ``"train"`` entry whose rows are dict-like.

    Returns:
        The first column name whose value in that row is a ``str``, or
        ``None`` when the row has no string value.
    """
    first_row = dataset["train"][0]
    string_columns = (
        name for name, value in first_row.items() if isinstance(value, str)
    )
    return next(string_columns, None)

# Model training function
def train_model(dataset_url, model_url, epochs):
    """Fine-tune a causal language model with LoRA on CPU.

    Loads the tokenizer and model, wraps the model with a LoRA adapter,
    tokenizes the first string column of the dataset, and runs a Hugging Face
    ``Trainer``. Progress and errors are reported through ``log_status`` so
    that another process can display them. After every epoch the model is
    saved to ``./deepseek_lora_finetuned_epoch_<n>``.

    Args:
        dataset_url: Dataset identifier passed to ``load_dataset``.
        model_url: Model identifier passed to ``from_pretrained``.
        epochs: Number of training epochs; converted with ``int()`` and
            required to be at least 1.

    Returns:
        None. Every outcome, including failure, is written to the status
        file; exceptions are caught and logged rather than propagated.
    """
    # Imported locally: these names are not part of the module-level imports.
    import inspect

    from transformers import TrainerCallback

    try:
        num_epochs = int(epochs)
        if num_epochs < 1:
            raise ValueError(f"epochs must be >= 1, got {epochs!r}")

        log_status("🚀 در حال بارگیری مدل...")

        # NOTE(security): trust_remote_code=True executes Python code shipped
        # with the model repository. model_url comes straight from the UI, so
        # only trusted repositories should be entered there.
        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        # padding="max_length" below requires a pad token; many causal-LM
        # tokenizers ship without one, so fall back to the EOS token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_url, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
        )

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"]
        )
        model = get_peft_model(model, lora_config)
        model.to(device)
        # With gradient checkpointing and a frozen base model, the inputs to
        # the checkpointed blocks must require grad or the backward pass has
        # nothing to differentiate.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()

        dataset = load_dataset(dataset_url)

        # Automatically detect the correct text column
        text_column = find_text_column(dataset)
        if not text_column:
            log_status("❌ خطا: ستون متنی در دیتاست یافت نشد!")
            return

        def tokenize_function(examples):
            encoded = tokenizer(
                examples[text_column],
                truncation=True,
                padding="max_length",
                max_length=256,
                return_attention_mask=True,
            )
            # The Trainer needs labels to obtain a loss. For causal LM the
            # labels are the input ids; padded positions are set to -100 so
            # the loss ignores them.
            encoded["labels"] = [
                [token if keep else -100 for token, keep in zip(ids, mask)]
                for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
            ]
            return encoded

        tokenized_datasets = dataset.map(tokenize_function, batched=True)
        train_dataset = tokenized_datasets["train"]

        # Use a validation split for per-epoch evaluation only if one exists
        # and is non-empty.
        eval_dataset = tokenized_datasets["validation"] if "validation" in tokenized_datasets else None
        has_eval = eval_dataset is not None and len(eval_dataset) > 0

        training_kwargs = dict(
            output_dir="./deepseek_lora_cpu",
            learning_rate=5e-4,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            num_train_epochs=num_epochs,
            save_strategy="epoch",
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,
            gradient_checkpointing=True,
            optim="adamw_torch",
            report_to="none",
        )
        # Newer transformers releases renamed `evaluation_strategy` to
        # `eval_strategy`; use whichever name the installed version accepts.
        accepted = inspect.signature(TrainingArguments).parameters
        strategy_key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"
        training_kwargs[strategy_key] = "epoch" if has_eval else "no"
        training_args = TrainingArguments(**training_kwargs)

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset if has_eval else None,
        )

        class _EpochStatusCallback(TrainerCallback):
            """Report the running epoch and save the model after each one."""

            def __init__(self):
                self.current_epoch = 0

            def on_epoch_begin(self, args, state, control, **kwargs):
                self.current_epoch += 1
                log_status(f"🔄 در حال اجرا: Epoch {self.current_epoch}/{num_epochs}...")

            def on_epoch_end(self, args, state, control, **kwargs):
                trainer.save_model(f"./deepseek_lora_finetuned_epoch_{self.current_epoch}")

        trainer.add_callback(_EpochStatusCallback())

        log_status("🚀 ترینینگ شروع شد...")

        # A single train() call runs all epochs (num_train_epochs above).
        # resume_from_checkpoint=True is deliberately not passed: it raises
        # when the output directory holds no checkpoint, which is always the
        # case on a first run.
        trainer.train()

        log_status("✅ ترینینگ کامل شد!")

    except Exception as e:
        # Process-level boundary: the UI can only see the status file, so
        # every failure is reported there instead of being raised.
        log_status(f"❌ خطا: {str(e)}")

# Start training in a separate process
def start_training(dataset_url, model_url, epochs):
    """Launch ``train_model`` in a background process unless one is running.

    Args:
        dataset_url: Dataset identifier forwarded to ``train_model``.
        model_url: Model identifier forwarded to ``train_model``.
        epochs: Epoch count forwarded to ``train_model``.

    Returns:
        A status string: "training started" when a new process was launched,
        or a warning when a previous training process is still alive.
    """
    global training_process

    already_running = training_process is not None and training_process.is_alive()
    if already_running:
        return "⚠ ترینینگ در حال اجرا است!"

    training_process = multiprocessing.Process(
        target=train_model,
        args=(dataset_url, model_url, epochs),
    )
    training_process.start()
    return "🚀 ترینینگ شروع شد!"

# Function to update the status
def update_status():
    """Return the current training status text for display in the UI."""
    current_status = read_status()
    return current_status

# Gradio UI
# Build the interface: three inputs, a start button, and a status box that is
# refreshed on demand from the status file.
with gr.Blocks() as app:
    gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - نمایش وضعیت لحظه‌ای")

    with gr.Row():
        dataset_input = gr.Textbox(label="📂 لینک دیتاست (Hugging Face)")
        model_input = gr.Textbox(label="🤖 مدل پایه (Hugging Face)")
        epochs_input = gr.Number(label="🔄 تعداد Epochs", value=3)

    start_button = gr.Button("🚀 شروع ترینینگ")
    status_output = gr.Textbox(label="📢 وضعیت ترینینگ", interactive=False)

    start_button.click(start_training, inputs=[dataset_input, model_input, epochs_input], outputs=status_output)
    status_button = gr.Button("🔄 بروزرسانی وضعیت")
    status_button.click(update_status, outputs=status_output)

# Launch only when executed as a script. With the multiprocessing "spawn"
# start method the training child process re-imports this module; an
# unguarded launch() would start a second UI there instead of training.
if __name__ == "__main__":
    app.launch()