# Tuning / app.py — hackergeek, "Update app.py" (commit 006af89, verified)
# (Hugging Face page chrome preserved as a comment: raw / history / blame, 4.7 kB)
import torch
import gradio as gr
import multiprocessing
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
# Training always runs on CPU; no GPU is assumed in this Space.
device = "cpu"
# Handle to the background training process (None until training is started).
training_process = None
# File used to pass status messages from the worker process to the UI.
log_file = "training_status.log"
# Logging function
def log_status(message):
    """Overwrite the status file with *message* so the UI can display it.

    Mode "w" is deliberate: only the most recent status line is kept.
    """
    # BUG FIX: explicit UTF-8 — the status strings are Persian, and the
    # platform default encoding (e.g. cp1252 on Windows) would raise
    # UnicodeEncodeError on them.
    with open(log_file, "w", encoding="utf-8") as f:
        f.write(message)
# Read training status
def read_status():
    """Return the current contents of the status file.

    Falls back to a "waiting to start" message when no status has been
    written yet (file does not exist).
    """
    if os.path.exists(log_file):
        # BUG FIX: read with the same explicit UTF-8 encoding the writer
        # uses, so the Persian status text round-trips on any platform.
        with open(log_file, "r", encoding="utf-8") as f:
            return f.read()
    return "⏳ در انتظار شروع ترینینگ..."
# Function to find the text column dynamically
def find_text_column(dataset):
    """Return the name of the first string-valued column in the train split.

    Only the first row of ``dataset["train"]`` is inspected; returns
    ``None`` when no column holds a string value.
    """
    first_row = dataset["train"][0]
    return next(
        (name for name, value in first_row.items() if isinstance(value, str)),
        None,
    )
# Model training function
def train_model(dataset_url, model_url, epochs):
    """Fine-tune a causal LM with LoRA on CPU, reporting progress via log_status.

    Runs inside a separate process (see start_training); every outcome —
    including errors — is written to the status file rather than raised.

    Args:
        dataset_url: Hugging Face dataset repo id to load with load_dataset.
        model_url: Hugging Face model repo id for tokenizer and model.
        epochs: number of training epochs (coerced with int()).
    """
    try:
        log_status("🚀 در حال بارگیری مدل...")
        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        # BUG FIX: many causal-LM tokenizers (GPT/LLaMA-style) ship without a
        # pad token; padding="max_length" below would fail without one.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_url, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
        )
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"]
        )
        model = get_peft_model(model, lora_config)
        model.to(device)
        dataset = load_dataset(dataset_url)
        # Automatically detect the correct text column
        text_column = find_text_column(dataset)
        if not text_column:
            log_status("❌ خطا: ستون متنی در دیتاست یافت نشد!")
            return

        def tokenize_function(examples):
            return tokenizer(examples[text_column], truncation=True, padding="max_length", max_length=256)

        # Drop the raw columns so the Trainer only sees tokenized model inputs.
        tokenized_datasets = dataset.map(
            tokenize_function, batched=True, remove_columns=dataset["train"].column_names
        )
        train_dataset = tokenized_datasets["train"]
        # BUG FIX: the original set evaluation_strategy="epoch" without ever
        # providing an eval_dataset, which makes Trainer raise at construction.
        training_args = TrainingArguments(
            output_dir="./deepseek_lora_cpu",
            evaluation_strategy="no",
            learning_rate=5e-4,
            per_device_train_batch_size=1,
            num_train_epochs=int(epochs),
            save_strategy="epoch",
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,
            gradient_checkpointing=True,
            optim="adamw_torch",
            report_to="none"
        )
        # BUG FIX: a causal-LM collator is required to build the `labels`
        # field from input_ids; without it the model returns no loss and
        # Trainer.train crashes. Local import keeps the file-level imports
        # untouched.
        from transformers import DataCollatorForLanguageModeling
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=data_collator
        )
        log_status("🚀 ترینینگ شروع شد...")
        # BUG FIX: the original looped `epochs` times calling
        # trainer.train(resume_from_checkpoint=True): the first call fails
        # because no checkpoint exists yet, and each call would itself run the
        # full num_train_epochs. A single call is correct —
        # save_strategy="epoch" still checkpoints after every epoch.
        trainer.train()
        trainer.save_model("./deepseek_lora_finetuned")
        log_status("✅ ترینینگ کامل شد!")
    except Exception as e:
        # Surface any failure to the UI instead of dying silently in the
        # background process.
        log_status(f"❌ خطا: {str(e)}")
# Start training in a separate process
def start_training(dataset_url, model_url, epochs):
    """Launch train_model in a background process unless one is already running.

    Returns a status string for the UI in either case.
    """
    global training_process
    already_running = training_process is not None and training_process.is_alive()
    if already_running:
        return "⚠ ترینینگ در حال اجرا است!"
    training_process = multiprocessing.Process(
        target=train_model, args=(dataset_url, model_url, epochs)
    )
    training_process.start()
    return "🚀 ترینینگ شروع شد!"
# Function to update the status
def update_status():
    """Return the latest training status text; wired to the refresh button."""
    return read_status()
# Gradio UI
# NOTE(review): original indentation was lost; nesting below is reconstructed —
# inputs inside the Row, everything else at Blocks level, launch() outside.
with gr.Blocks() as app:
    gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - نمایش وضعیت لحظه‌ای")
    with gr.Row():
        # Inputs: Hugging Face dataset repo id, base model repo id, epoch count.
        dataset_input = gr.Textbox(label="📂 لینک دیتاست (Hugging Face)")
        model_input = gr.Textbox(label="🤖 مدل پایه (Hugging Face)")
        epochs_input = gr.Number(label="🔄 تعداد Epochs", value=3)
    start_button = gr.Button("🚀 شروع ترینینگ")
    # Read-only textbox showing the contents of the status log file.
    status_output = gr.Textbox(label="📢 وضعیت ترینینگ", interactive=False)
    # Kick off training in a background process; shows an immediate message.
    start_button.click(start_training, inputs=[dataset_input, model_input, epochs_input], outputs=status_output)
    # Manual refresh: re-reads the status file into the textbox.
    status_button = gr.Button("🔄 بروزرسانی وضعیت")
    status_button.click(update_status, outputs=status_output)
app.launch()