Spaces:
Running
Running
File size: 4,998 Bytes
2b35f7d 0b40748 a51c773 0b40748 a51c773 006af89 a51c773 006af89 a51c773 0b40748 006af89 0b40748 a51c773 006af89 e784c2f 777e328 a51c773 777e328 0b40748 006af89 0b40748 006af89 0b40748 006af89 0b40748 71e1bba 0b40748 71e1bba 0b40748 a51c773 0b40748 a51c773 e784c2f 0b40748 006af89 71e1bba 0b40748 a51c773 0b40748 a51c773 3f9d883 0b40748 a51c773 006af89 a51c773 0b40748 006af89 e88f543 006af89 0b40748 a51c773 3f9d883 006af89 a51c773 006af89 0b40748 006af89 2b35f7d 006af89 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import inspect
import multiprocessing
import os
import traceback

import gradio as gr
import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint
# Device the LoRA model is moved to before training.
device: str = "cpu"
# Handle to the background training process; stays None until start_training runs.
training_process = None
# Status file written by log_status and read back by read_status.
log_file: str = "training_status.log"
# Logging function
def log_status(message, path=None):
    """Overwrite the status file with *message*.

    The file is truncated on every call, so it only ever holds the latest
    status. It is written as UTF-8 explicitly because the status messages
    contain Persian text and emoji, which the platform default encoding
    (e.g. cp1252 on Windows) cannot encode.

    Args:
        message: Status text to store.
        path: File to write; defaults to the module-level ``log_file``.
    """
    if path is None:
        path = log_file
    with open(path, "w", encoding="utf-8") as f:
        f.write(message)
# Read training status
def read_status(path=None):
    """Return the current training status text.

    Args:
        path: Status file to read; defaults to the module-level ``log_file``.

    Returns:
        The file's full contents decoded as UTF-8, or a "waiting to start"
        placeholder when the file does not exist yet.
    """
    if path is None:
        path = log_file
    # EAFP: avoids the race between an exists() check and the open().
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        # Nothing has been logged yet.
        return "⏳ در انتظار شروع ترینینگ..."
# Function to find the text column dynamically
def find_text_column(dataset):
    """Return the name of the first string-valued column in the first "train" row.

    Only the first row of ``dataset["train"]`` is inspected. Returns None when
    none of that row's values is a ``str``.
    """
    first_row = dataset["train"][0]
    return next(
        (name for name, value in first_row.items() if isinstance(value, str)),
        None,
    )
# Model training function
class _EpochStatusCallback(TrainerCallback):
    """Write per-epoch progress to the status file and save an adapter snapshot per epoch."""

    def __init__(self, total_epochs):
        self.total_epochs = total_epochs

    def on_epoch_begin(self, args, state, control, **kwargs):
        # Assumes Trainer keeps state.epoch at the number of completed epochs
        # when an epoch starts (also after resuming) -- TODO confirm per version.
        current = round(state.epoch or 0) + 1
        log_status(f"🔄 در حال اجرا: Epoch {current}/{self.total_epochs}...")

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        if model is not None:
            model.save_pretrained(f"./deepseek_lora_finetuned_epoch_{round(state.epoch or 0)}")


def _load_lora_model(model_url):
    """Load the tokenizer and causal LM at *model_url* and wrap the model with LoRA.

    Returns:
        ``(tokenizer, peft_model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Many causal-LM tokenizers define no pad token, and padding="max_length"
        # fails without one; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_url, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
    )
    # Gradient checkpointing is enabled in the training arguments while LoRA
    # freezes the base weights; without input grads the checkpointed segments
    # produce outputs with no grad_fn and backward() fails.
    model.enable_input_require_grads()
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.to(device)
    return tokenizer, model


def _make_training_args(output_dir, num_epochs, has_eval):
    """Build the CPU-oriented TrainingArguments; evaluate per epoch only if *has_eval*."""
    # `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
    # releases (and the old name later removed); pass whichever this install accepts.
    accepted = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"
    return TrainingArguments(
        output_dir=output_dir,
        learning_rate=5e-4,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_train_epochs=num_epochs,
        save_strategy="epoch",
        save_total_limit=2,
        logging_dir="./logs",
        logging_steps=10,
        fp16=False,
        gradient_checkpointing=True,
        optim="adamw_torch",
        report_to="none",
        **{eval_key: "epoch" if has_eval else "no"},
    )


def train_model(dataset_url, model_url, epochs):
    """Fine-tune *model_url* with a LoRA adapter on *dataset_url* for *epochs* epochs.

    Meant to run as the target of a background process: it returns nothing and
    reports progress, completion and errors only through ``log_status``.

    Args:
        dataset_url: Dataset id/path passed to ``datasets.load_dataset``; must
            have a "train" split, and a "validation" split is used if present.
        model_url: Model id/path for the tokenizer and causal LM.
        epochs: Number of epochs; anything ``int()`` accepts.
    """
    output_dir = "./deepseek_lora_cpu"
    try:
        num_epochs = int(epochs)
        log_status("🚀 در حال بارگیری مدل...")
        tokenizer, model = _load_lora_model(model_url)

        dataset = load_dataset(dataset_url)
        # Automatically detect the correct text column
        text_column = find_text_column(dataset)
        if not text_column:
            log_status("❌ خطا: ستون متنی در دیتاست یافت نشد!")
            return

        def tokenize_function(examples):
            return tokenizer(examples[text_column], truncation=True, padding="max_length", max_length=256)

        tokenized_datasets = dataset.map(tokenize_function, batched=True)
        train_dataset = tokenized_datasets["train"]
        # Evaluate per epoch only when a validation split exists.
        eval_dataset = tokenized_datasets["validation"] if "validation" in tokenized_datasets else None

        trainer = Trainer(
            model=model,
            args=_make_training_args(output_dir, num_epochs, bool(eval_dataset)),
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            # The tokenized rows carry no "labels"; this collator derives them
            # from input_ids (mlm=False) so the causal LM returns a loss.
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
            callbacks=[_EpochStatusCallback(num_epochs)],
        )

        log_status("🚀 ترینینگ شروع شد...")
        # A single train() call runs all epochs. Resume only if an earlier run
        # left a checkpoint: resume_from_checkpoint=True raises when none exists.
        last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
        trainer.train(resume_from_checkpoint=last_checkpoint)
        log_status("✅ ترینینگ کامل شد!")
    except Exception as e:
        # Top-level boundary of the worker process: keep the full traceback in
        # the process output and surface the message in the UI status.
        traceback.print_exc()
        log_status(f"❌ خطا: {str(e)}")
# Start training in a separate process
def start_training(dataset_url, model_url, epochs):
    """Launch ``train_model`` in a background process unless one is still alive.

    Returns:
        A user-facing status string for the Gradio textbox.
    """
    global training_process
    busy = training_process is not None and training_process.is_alive()
    if busy:
        return "⚠ ترینینگ در حال اجرا است!"
    training_process = multiprocessing.Process(
        target=train_model,
        args=(dataset_url, model_url, epochs),
    )
    training_process.start()
    return "🚀 ترینینگ شروع شد!"
# Function to update the status
def update_status():
    """Return the latest training status text (used by the refresh button)."""
    current_status = read_status()
    return current_status
# Gradio UI
with gr.Blocks() as app:
    gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - نمایش وضعیت لحظهای")
    with gr.Row():
        dataset_input = gr.Textbox(label="📂 لینک دیتاست (Hugging Face)")
        model_input = gr.Textbox(label="🤖 مدل پایه (Hugging Face)")
        epochs_input = gr.Number(label="🔄 تعداد Epochs", value=3)
    start_button = gr.Button("🚀 شروع ترینینگ")
    status_output = gr.Textbox(label="📢 وضعیت ترینینگ", interactive=False)
    start_button.click(start_training, inputs=[dataset_input, model_input, epochs_input], outputs=status_output)
    status_button = gr.Button("🔄 بروزرسانی وضعیت")
    status_button.click(update_status, outputs=status_output)

# Launch only when run as a script. With the "spawn"/"forkserver" multiprocessing
# start methods the training child re-imports this module (not as __main__), and
# an unguarded launch() would start a second server inside that child.
if __name__ == "__main__":
    app.launch()