# Tuning / app.py
# hackergeek — Update app.py
# commit 71e1bba (verified)
import multiprocessing
import os

import gradio as gr
import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
# Force CPU execution (this Space has no GPU); the model is moved here before training.
device = "cpu"
# Handle to the background training process started by start_training(); None until the first launch.
training_process = None
# Status file written by the training subprocess and polled by the Gradio UI.
log_file = "training_status.log"
def log_status(message, path="training_status.log"):
    """Overwrite the status file with *message* (last status wins).

    Fixes: writes with an explicit UTF-8 encoding — the messages contain
    Persian text and emoji, which mojibake or crash under a non-UTF-8
    locale default. The target path is now a parameter (defaulting to the
    module's original log file) so the helper is reusable and testable.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.write(message)
def read_status(path="training_status.log"):
    """Return the current training status text.

    Reads the status file if it exists; otherwise returns the Persian
    "waiting for training to start" placeholder. Fixes: reads with an
    explicit UTF-8 encoding (the file holds Persian text/emoji), and the
    path is now a parameter defaulting to the module's original log file.
    """
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    return "⏳ در انتظار شروع ترینینگ..."
# Detect which dataset column holds the training text.
def find_text_column(dataset):
    """Return the name of the first string-valued column in the train split.

    Inspects only the first row of ``dataset["train"]``; returns ``None``
    when no column holds a string.
    """
    first_row = dataset["train"][0]
    return next(
        (name for name, value in first_row.items() if isinstance(value, str)),
        None,
    )
# Model training function (runs in a child process; reports progress via log_status).
def train_model(dataset_url, model_url, epochs):
    """Fine-tune a causal LM with a LoRA adapter on CPU.

    Parameters
    ----------
    dataset_url : str
        Hugging Face dataset repo id; must contain a string (text) column.
    model_url : str
        Hugging Face model repo id of the base causal LM.
    epochs : int | float
        Number of training epochs (coerced with ``int()``).

    Side effects: per-epoch checkpoints in ``./deepseek_lora_cpu``, the final
    adapter in ``./deepseek_lora_finetuned``, and status text in the log file.
    Never raises — every failure is reported through ``log_status``.
    """
    try:
        log_status("🚀 در حال بارگیری مدل...")
        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        # Many causal-LM tokenizers (GPT-style) ship without a pad token; the
        # padding="max_length" call below would fail without one, so fall back to EOS.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_url, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
        )
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"]
        )
        model = get_peft_model(model, lora_config)
        model.to(device)

        dataset = load_dataset(dataset_url)
        # Automatically detect the first string-valued column to train on.
        text_column = find_text_column(dataset)
        if not text_column:
            log_status("❌ خطا: ستون متنی در دیتاست یافت نشد!")
            return

        def tokenize_function(examples):
            # Fixed-length padding keeps batch shapes constant on CPU.
            return tokenizer(examples[text_column], truncation=True, padding="max_length", max_length=256)

        tokenized_datasets = dataset.map(tokenize_function, batched=True)
        train_dataset = tokenized_datasets["train"]
        # Evaluate each epoch only when the dataset ships a validation split.
        eval_dataset = tokenized_datasets["validation"] if "validation" in tokenized_datasets else None

        training_args = TrainingArguments(
            output_dir="./deepseek_lora_cpu",
            evaluation_strategy="epoch" if eval_dataset else "no",
            learning_rate=5e-4,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            num_train_epochs=int(epochs),
            save_strategy="epoch",  # one checkpoint per epoch in output_dir
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,  # CPU training: no mixed precision
            gradient_checkpointing=True,
            optim="adamw_torch",
            report_to="none"
        )

        # BUG FIX: the tokenized dataset has no "labels" column, so Trainer could
        # not compute a loss. DataCollatorForLanguageModeling(mlm=False) copies
        # input_ids into labels, which is the standard causal-LM setup.
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator
        )

        log_status("🚀 ترینینگ شروع شد...")
        # BUG FIX: the old per-epoch loop called trainer.train(resume_from_checkpoint=True)
        # `epochs` times — the first call crashed because no checkpoint existed yet, and
        # each later call re-ran the full num_train_epochs schedule. One train() call
        # already performs all epochs; save_strategy="epoch" keeps per-epoch checkpoints.
        trainer.train()
        trainer.save_model("./deepseek_lora_finetuned")
        log_status("✅ ترینینگ کامل شد!")
    except Exception as e:
        # Surface any failure to the UI instead of dying silently in the subprocess.
        log_status(f"❌ خطا: {str(e)}")
# Launch training in a separate process so the Gradio handler returns immediately.
def start_training(dataset_url, model_url, epochs):
    """Spawn the training subprocess unless one is already running.

    Returns a short Persian status string for the UI textbox.
    """
    global training_process
    already_running = training_process is not None and training_process.is_alive()
    if already_running:
        return "⚠ ترینینگ در حال اجرا است!"
    training_process = multiprocessing.Process(
        target=train_model,
        args=(dataset_url, model_url, epochs),
    )
    training_process.start()
    return "🚀 ترینینگ شروع شد!"
# Status-poll callback wired to the UI refresh button.
def update_status():
    """Return the latest training status text read from the log file."""
    return read_status()
# --- Gradio UI ---------------------------------------------------------------
# Flat script section: builds the Blocks layout and wires the two buttons.
with gr.Blocks() as app:
    gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - نمایش وضعیت لحظه‌ای")
    with gr.Row():
        # Inputs: dataset repo id and base-model repo id (Hugging Face hub paths),
        # plus the number of training epochs.
        dataset_input = gr.Textbox(label="📂 لینک دیتاست (Hugging Face)")
        model_input = gr.Textbox(label="🤖 مدل پایه (Hugging Face)")
        epochs_input = gr.Number(label="🔄 تعداد Epochs", value=3)
    start_button = gr.Button("🚀 شروع ترینینگ")
    status_output = gr.Textbox(label="📢 وضعیت ترینینگ", interactive=False)
    # Kicks off the background training process; the returned message is shown immediately.
    start_button.click(start_training, inputs=[dataset_input, model_input, epochs_input], outputs=status_output)
    # Manual refresh: re-reads the status file written by the training subprocess.
    status_button = gr.Button("🔄 بروزرسانی وضعیت")
    status_button.click(update_status, outputs=status_output)
app.launch()