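"""
Gradio app that LoRA-fine-tunes a Hugging Face causal language model
(e.g. a DeepSeek R1 checkpoint) on a Hugging Face dataset with a "text"
column. On a CUDA GPU the base model is loaded with 4-bit (QLoRA-style)
quantization; on CPU it is loaded in full precision. Dataset, model, and
epoch count are taken from the UI.
"""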
import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)
from peft import get_peft_model, prepare_model_for_kbit_training, LoraConfig, TaskType
from datasets import load_dataset

# βœ… Check if a GPU is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# βœ… Function to start training
def train_model(dataset_url, model_url, epochs):
    try:
        # Load the tokenizer (many causal-LM tokenizers ship without a pad
        # token, so fall back to the EOS token for padding)
        tokenizer = AutoTokenizer.from_pretrained(model_url)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # βœ… Load the model with 4-bit quantization when a GPU is available
        # (bitsandbytes quantization requires CUDA; on CPU the model is
        # loaded in full precision instead)
        bnb_config = None
        if device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True
            )

        model = AutoModelForCausalLM.from_pretrained(
            model_url,
            quantization_config=bnb_config,
            device_map=device
        )

        # βœ… Prepare the model, then apply LoRA so only small adapter
        # matrices are trained while the base weights stay frozen
        if device == "cuda":
            # Casts norms/head and enables input grads for k-bit training
            model = prepare_model_for_kbit_training(model)
        else:
            # Needed so gradient checkpointing works with frozen base weights
            model.enable_input_require_grads()

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,                 # rank of the low-rank update matrices
            lora_alpha=32,       # scaling factor applied to the LoRA update
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"]  # attention projections to adapt
        )

        model = get_peft_model(model, lora_config)
        # device_map above already placed the model; calling .to(device)
        # would raise an error on a 4-bit quantized model, so it is omitted.
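        # Optional sanity check: LoRA should leave only a tiny fraction of
        # the parameters trainable (print_trainable_parameters is a PEFT helper)
        model.print_trainable_parameters()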

        # βœ… Load dataset
        dataset = load_dataset(dataset_url)

        # βœ… Tokenization function (expects the dataset to have a "text" column)
        def tokenize_function(examples):
            return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=256)

        tokenized_datasets = dataset.map(tokenize_function, batched=True)
        train_dataset = tokenized_datasets["train"]

        # βœ… Collator that copies input_ids into labels so the Trainer can
        # compute a causal-LM loss (without it, training fails with no loss)
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

        # βœ… Training Arguments (no eval dataset is passed to the Trainer,
        # so evaluation is left at its default of disabled)
        training_args = TrainingArguments(
            output_dir="./deepseek_lora_cpu",
            learning_rate=5e-4,
            per_device_train_batch_size=1,
            num_train_epochs=int(epochs),
            save_strategy="epoch",
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,
            gradient_checkpointing=True,
            optim="adamw_torch",
            report_to="none"
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=data_collator
        )

        # βœ… Start Training
        trainer.train()

        # βœ… Save the Fine-Tuned Model
        model.save_pretrained("./deepseek_lora_finetuned")
        tokenizer.save_pretrained("./deepseek_lora_finetuned")
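        # Note: on a PEFT model, save_pretrained writes only the small LoRA
        # adapter weights, not the full base model; to reuse them, reload the
        # base model and attach the adapter with PeftModel.from_pretrained.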

        return "βœ… Training Completed! Model saved successfully."
    
    except Exception as e:
        return f"❌ Error: {str(e)}"

# βœ… Gradio UI
with gr.Blocks() as app:
    gr.Markdown("# πŸš€ AutoTrain DeepSeek R1 (CPU)")

    dataset_url = gr.Textbox(label="Dataset URL (Hugging Face)", placeholder="e.g. imdb (dataset needs a 'text' column)")
    model_url = gr.Textbox(label="Model URL (Hugging Face)", placeholder="e.g. deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
    epochs = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Training Epochs")

    train_button = gr.Button("Start Training")
    output_text = gr.Textbox(label="Training Output")

    train_button.click(train_model, inputs=[dataset_url, model_url, epochs], outputs=output_text)

# βœ… Launch the app
app.launch(server_name="0.0.0.0", server_port=7860)
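
# To try the app locally (assuming this file is saved as app.py):
#   pip install torch transformers peft datasets bitsandbytes accelerate gradio
#   python app.py
# then open http://localhost:7860, enter a dataset with a "text" column and a
# model repo id, and click "Start Training".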