# app.py
import os
import json
import torch
import pandas as pd
import gradio as gr
from sqlalchemy import create_engine, text
from transformers import (
    TrainingArguments,
    Trainer,
    TrainerCallback,
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model
)
from datetime import datetime
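# Likely runtime dependencies for this Space (a sketch, not verified against a
# specific requirements.txt): gradio, torch, transformers, datasets, peft, pandas,
# sqlalchemy, accelerate, and bitsandbytes (needed for load_in_8bit below).
# A Postgres driver such as psycopg2-binary is assumed for the DATABASE_URL connection.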
# Constants - Modified for HF Spaces
MODEL_NAME = "deepseek-ai/DeepSeek-R1"
OUTPUT_DIR = "/tmp/finetuned_models"  # Using /tmp for HF Spaces
LOGS_DIR = "/tmp/training_logs"       # Using /tmp for HF Spaces
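# Note: /tmp on a Hugging Face Space is ephemeral; anything written there is lost
# when the Space restarts. To keep a fine-tuned adapter, push it to the Hub
# (e.g. model.push_to_hub(...)) or attach persistent storage to the Space.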

class TrainingInterface:
    def __init__(self):
        self.current_status = "Idle"
        self.progress = 0
        self.is_training = False

    def get_database_url(self):
        """Get database URL from HF Space secrets"""
        database_url = os.environ.get('DATABASE_URL')
        if not database_url:
            raise Exception("DATABASE_URL not found in environment variables")
        return database_url
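    # The DATABASE_URL secret is expected to be a SQLAlchemy-style URL; the exact
    # value depends on your database. A hypothetical Postgres example:
    #     postgresql://user:password@host:5432/dbname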
    def fetch_training_data(self, progress=gr.Progress()):
        """Fetch training data from database"""
        try:
            database_url = self.get_database_url()
            engine = create_engine(database_url)
            progress(0, desc="Connecting to database...")
            with engine.connect() as conn:
                result = conn.execute(text("SELECT COUNT(*) FROM bents"))
                total_rows = result.scalar()
                query = text("SELECT chunk_id, text FROM bents")
                df = pd.read_sql_query(query, conn)
            progress(0.5, desc=f"Fetched {total_rows} rows")
            return df
        except Exception as e:
            raise gr.Error(f"Database error: {str(e)}")
    def prepare_training_data(self, df, progress=gr.Progress()):
        """Convert DataFrame into training format"""
        formatted_data = []
        try:
            total_rows = len(df)
            for idx, (_, row_data) in enumerate(df.iterrows()):
                progress(idx / total_rows, desc="Preparing training data...")
                chunk_id = str(row_data['chunk_id']).strip()
                # local name chunk_text avoids shadowing sqlalchemy's text()
                chunk_text = str(row_data['text']).strip()
                if chunk_id and chunk_text:
                    formatted_text = f"User: {chunk_id}\nAssistant: {chunk_text}"
                    formatted_data.append({"text": formatted_text})
            if not formatted_data:
                raise ValueError("No valid training data found")
            return formatted_data
        except Exception as e:
            raise gr.Error(f"Data preparation error: {str(e)}")
    def stop_training(self):
        """Stop the training process"""
        self.is_training = False
        return "Training stopped by user."
    def train_model(
        self,
        learning_rate=2e-4,
        num_epochs=3,
        batch_size=4,
        progress=gr.Progress()
    ):
        """Main training function"""
        try:
            self.is_training = True
            # Create directories in /tmp for HF Spaces
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            specific_output_dir = os.path.join(OUTPUT_DIR, f"run_{timestamp}")
            os.makedirs(specific_output_dir, exist_ok=True)
            os.makedirs(LOGS_DIR, exist_ok=True)

            # Data preparation
            progress(0.1, desc="Fetching data...")
            if not self.is_training:
                return "Training cancelled."
            df = self.fetch_training_data()
            formatted_data = self.prepare_training_data(df)
            # Model initialization
            progress(0.2, desc="Loading model...")
            if not self.is_training:
                return "Training cancelled."
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
            if tokenizer.pad_token is None:
                # Causal LM tokenizers often ship without a pad token; reuse EOS for padding
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                load_in_8bit=True,  # requires bitsandbytes
                device_map="auto"   # Important for HF Spaces GPU allocation
            )
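            # Recent transformers versions deprecate passing load_in_8bit directly;
            # an equivalent sketch (not taken from the original) would pass a
            # quantization config instead:
            #     from transformers import BitsAndBytesConfig
            #     model = AutoModelForCausalLM.from_pretrained(
            #         MODEL_NAME,
            #         trust_remote_code=True,
            #         quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            #         device_map="auto",
            #     )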
            # LoRA configuration
            progress(0.3, desc="Setting up LoRA...")
            if not self.is_training:
                return "Training cancelled."
            lora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=[
                    "q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"
                ],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM"
            )
            model = prepare_model_for_kbit_training(model)
            model = get_peft_model(model, lora_config)
            # Training setup
            progress(0.4, desc="Configuring training...")
            if not self.is_training:
                return "Training cancelled."
            training_args = TrainingArguments(
                output_dir=specific_output_dir,
                num_train_epochs=num_epochs,
                per_device_train_batch_size=batch_size,
                learning_rate=learning_rate,
                fp16=True,
                gradient_accumulation_steps=8,
                gradient_checkpointing=True,
                logging_dir=os.path.join(LOGS_DIR, f"run_{timestamp}"),
                logging_steps=10,
                save_strategy="epoch",
                # no eval dataset is built here, so no evaluation_strategy
                save_total_limit=2,
                remove_unused_columns=False,  # Important for HF Spaces
            )
            dataset = Dataset.from_dict({
                'text': [item['text'] for item in formatted_data]
            })

            # Tokenize the formatted examples; the language-modeling collator
            # expects input_ids, not raw text (max_length here is an arbitrary choice)
            def tokenize_function(examples):
                return tokenizer(examples['text'], truncation=True, max_length=512)

            dataset = dataset.map(tokenize_function, batched=True, remove_columns=['text'])
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=tokenizer,
                mlm=False
            )
            # Custom progress callback (must be a TrainerCallback, not gr.Progress)
            class ProgressCallback(TrainerCallback):
                def __init__(self, progress_callback, training_interface):
                    self.progress_callback = progress_callback
                    self.training_interface = training_interface

                def on_train_begin(self, args, state, control, **kwargs):
                    if not self.training_interface.is_training:
                        control.should_training_stop = True
                    self.progress_callback(0.5, desc="Training started...")

                def on_epoch_begin(self, args, state, control, **kwargs):
                    if not self.training_interface.is_training:
                        control.should_training_stop = True
                    epoch_progress = state.epoch / args.num_train_epochs
                    total_progress = 0.5 + (epoch_progress * 0.4)
                    self.progress_callback(
                        total_progress,
                        desc=f"Training epoch {int(state.epoch) + 1}/{args.num_train_epochs}..."
                    )
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                data_collator=data_collator,
                callbacks=[ProgressCallback(progress, self)]
            )
            if not self.is_training:
                return "Training cancelled."
            trainer.train()
            if not self.is_training:
                return "Training cancelled."

            # Save model
            progress(0.9, desc="Saving model...")
            trainer.save_model()
            tokenizer.save_pretrained(specific_output_dir)
            progress(1.0, desc="Training completed!")
            self.is_training = False
            return f"Training completed! Model saved in {specific_output_dir}"
        except Exception as e:
            self.is_training = False
            raise gr.Error(f"Training error: {str(e)}")

def create_training_interface():
    """Create Gradio interface"""
    interface = TrainingInterface()
    with gr.Blocks(title="DeepSeek Model Training Interface") as app:
        gr.Markdown("# DeepSeek Model Fine-tuning Interface")
        with gr.Row():
            with gr.Column():
                learning_rate = gr.Slider(
                    minimum=1e-5,
                    maximum=1e-3,
                    value=2e-4,
                    label="Learning Rate"
                )
                num_epochs = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Number of Epochs"
                )
                batch_size = gr.Slider(
                    minimum=1,
                    maximum=8,
                    value=4,
                    step=1,
                    label="Batch Size"
                )
        with gr.Row():
            train_button = gr.Button("Start Training", variant="primary")
            stop_button = gr.Button("Stop Training", variant="secondary")
        output_text = gr.Textbox(
            label="Training Status",
            placeholder="Training status will appear here...",
            lines=10
        )
        train_button.click(
            fn=interface.train_model,
            inputs=[learning_rate, num_epochs, batch_size],
            outputs=output_text
        )
        stop_button.click(
            fn=interface.stop_training,
            inputs=[],
            outputs=output_text
        )
    return app

if __name__ == "__main__":
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(LOGS_DIR, exist_ok=True)
    app = create_training_interface()
    app.launch()
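    # Note (assumption, not from the original): train_model can run for a long time,
    # so on Spaces it is usually worth enabling Gradio's request queue before
    # launching, e.g. app.queue().launch(), so the call is not cut off by HTTP timeouts.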