# (Scrape artifact removed: "Spaces / Sleeping / Sleeping" was a Hugging Face
# Spaces status banner captured along with the source, not part of the program.)
# training_space/app.py (Training Space Backend)

# Standard library
import logging
import os
import subprocess
import uuid

# Third-party
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import HfApi, HfFolder

app = FastAPI()

# Append-mode file logging so earlier history survives restarts.
logging.basicConfig(
    filename='training.log',
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Define the expected payload structure | |
class TrainingRequest(BaseModel): | |
task: str # 'generation' or 'classification' | |
model_params: dict | |
model_name: str | |
dataset_content: str # The actual content of the dataset | |
# The Hugging Face API token must arrive via the environment. Fail fast at
# import time rather than letting authenticated calls fail later.
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
if not HF_API_TOKEN:
    raise ValueError("HF_API_TOKEN environment variable not set.")

# Persist the token so subsequent huggingface_hub operations are authenticated.
HfFolder.save_token(HF_API_TOKEN)
api = HfApi()
@app.get("/")
def read_root():
    """Root endpoint: return a short usage message for the API.

    NOTE(review): the source as received had no route decorator (it appears
    to have been stripped along with the file's indentation); ``@app.get("/")``
    is restored here so the handler is actually reachable — confirm against
    the original deployment.
    """
    return {
        "message": "Welcome to the Training Space API!",
        "instructions": "To train a model, send a POST request to /train with the required parameters."
    }
@app.post("/train")
def train_model(request: TrainingRequest):
    """Start a training run as a detached background subprocess.

    Writes the uploaded dataset into a fresh per-session directory, launches
    ``train_model.py`` against it, and immediately returns a ``session_id``
    the client can use to poll progress.

    Raises:
        HTTPException 400: unknown task or missing model_params keys.
        HTTPException 500: unexpected failure while setting up / launching.

    NOTE(review): the source as received had no route decorator;
    ``@app.post("/train")`` is restored to match the instructions string in
    the root endpoint — confirm against the original deployment.
    """
    # Validate OUTSIDE the broad try/except below, otherwise the
    # `except Exception` arm would swallow these HTTPExceptions and
    # re-wrap client errors as opaque 500s.
    if request.task not in ("generation", "classification"):
        raise HTTPException(
            status_code=400,
            detail="task must be 'generation' or 'classification'.",
        )
    required = ("num_layers", "attention_heads", "hidden_size",
                "vocab_size", "sequence_length")
    missing = [k for k in required if k not in request.model_params]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"model_params missing keys: {', '.join(missing)}",
        )
    try:
        logging.info(f"Received training request for model: {request.model_name}, Task: {request.task}")
        # Unique working directory for this training session.
        session_id = str(uuid.uuid4())
        session_dir = os.path.abspath(os.path.join("training_sessions", session_id))
        os.makedirs(session_dir, exist_ok=True)
        # Persist the uploaded dataset where the training script can read it.
        dataset_path = os.path.join(session_dir, "dataset.txt")
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.write(request.dataset_content)
        # BUG FIX: the child process runs with cwd=session_dir, so both the
        # script path and the dataset path must be absolute. The original
        # passed relative paths that resolved against the wrong directory —
        # "train_model.py" would not even be found inside session_dir.
        script_path = os.path.abspath("train_model.py")
        cmd = [
            "python", script_path,
            "--task", request.task,
            "--model_name", request.model_name,
            "--dataset", dataset_path,
            "--num_layers", str(request.model_params['num_layers']),
            "--attention_heads", str(request.model_params['attention_heads']),
            "--hidden_size", str(request.model_params['hidden_size']),
            "--vocab_size", str(request.model_params['vocab_size']),
            "--sequence_length", str(request.model_params['sequence_length'])
        ]
        # Fire-and-forget: training continues after this request returns.
        # List form + shell=False (the default) keeps user-supplied values
        # from being interpreted by a shell.
        subprocess.Popen(cmd, cwd=session_dir)
        logging.info(f"Training started for model: {request.model_name}, Session ID: {session_id}")
        return {"status": "Training started", "session_id": session_id}
    except Exception as e:
        logging.error(f"Error during training request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# Optional: Status Endpoint
@app.get("/status/{session_id}")
def get_status(session_id: str):
    """Return the accumulated training log for one session.

    Raises:
        HTTPException 404: session_id is malformed or has no log file.

    NOTE(review): the source as received had no route decorator;
    ``@app.get("/status/{session_id}")`` is restored — confirm the original
    route shape against the deployment.
    """
    # session_id is untrusted client input that gets interpolated into a
    # filesystem path: insist it parses as a UUID (the format /train issues)
    # to rule out '../' path traversal.
    try:
        uuid.UUID(session_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Session ID not found.") from None
    session_dir = f"./training_sessions/{session_id}"
    # NOTE(review): assumes train_model.py writes a per-session
    # "training.log" inside its working directory — verify in that script.
    log_file = os.path.join(session_dir, "training.log")
    if not os.path.exists(log_file):
        raise HTTPException(status_code=404, detail="Session ID not found.")
    with open(log_file, "r", encoding="utf-8") as f:
        logs = f.read()
    return {"session_id": session_id, "logs": logs}