# training_space/app.py (FastAPI Backend)
"""FastAPI backend that launches model-training runs as background subprocesses."""

import logging
import os
import subprocess
import sys
import uuid

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfApi, HfFolder  # NOTE(review): unused in this file; kept in case other modules rely on the import side effects
from pydantic import BaseModel

app = FastAPI()

# Configure logging: append all backend events to a single root-level log file.
logging.basicConfig(
    filename='training.log',
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)

# CORS configuration: browser origins allowed to call this API.
origins = [
    "https://Vishwas1-LLMBuilderPro.hf.space",  # Replace with your Gradio frontend Space URL
    "http://localhost",  # For local testing
    "https://web.postman.co",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Define the expected payload structure
class TrainingRequest(BaseModel):
    # 'generation' or 'classification'
    task: str
    # Hyperparameters; see the defaults applied in train_model() below.
    model_params: dict
    model_name: str
    # The name of an existing Hugging Face dataset
    dataset_name: str


# Root Endpoint
@app.get("/")
def read_root():
    """Return a welcome message with basic usage instructions."""
    return {
        "message": "Welcome to the Training Space API!",
        "instructions": "To train a model, send a POST request to /train with the required parameters."
    }


# Train Endpoint
@app.post("/train")
def train_model(request: TrainingRequest):
    """Start a training run in a detached background subprocess.

    Returns immediately with a generated ``session_id``; training proceeds
    asynchronously. Raises HTTP 500 if the launch itself fails.
    """
    try:
        logging.info(
            "Received training request for model: %s, Task: %s",
            request.model_name, request.task,
        )

        # Create a unique directory for this training session
        session_id = str(uuid.uuid4())
        session_dir = f"./training_sessions/{session_id}"
        os.makedirs(session_dir, exist_ok=True)

        # No need to save dataset content; the training script resolves the
        # dataset by name (presumably from the Hugging Face Hub — confirm
        # against train_model.py).
        dataset_name = request.dataset_name

        # Absolute path to the training script, located next to this file.
        TRAIN_MODEL_PATH = os.path.join(os.path.dirname(__file__), "train_model.py")

        # Use the current interpreter (sys.executable) rather than whatever
        # "python" resolves to on PATH, so the subprocess shares this
        # server's environment.
        cmd = [
            sys.executable, TRAIN_MODEL_PATH,
            "--task", request.task,
            "--model_name", request.model_name,
            "--dataset_name", dataset_name,  # Pass dataset_name instead of a dataset file path
            "--num_layers", str(request.model_params.get('num_layers', 12)),
            "--attention_heads", str(request.model_params.get('attention_heads', 1)),
            "--hidden_size", str(request.model_params.get('hidden_size', 64)),
            "--vocab_size", str(request.model_params.get('vocab_size', 30000)),
            "--sequence_length", str(request.model_params.get('sequence_length', 512)),
        ]

        # Fire-and-forget: start the training process as a background task
        # running from this file's directory.
        subprocess.Popen(cmd, cwd=os.path.dirname(__file__))

        logging.info(
            "Training started for model: %s, Session ID: %s",
            request.model_name, session_id,
        )
        return {"status": "Training started", "session_id": session_id}
    except Exception as e:
        # Single handler — the original contained an identical, unreachable
        # duplicate except clause, removed here.
        logging.error("Error during training request: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


# Optional: Status Endpoint
@app.get("/status/{session_id}")
def get_status(session_id: str):
    """Return the accumulated training log for a session, or 404 if absent.

    NOTE(review): this assumes train_model.py writes a ``training.log``
    inside the session directory — confirm, since this backend itself logs
    to ``./training.log`` at the repository root.
    """
    session_dir = f"./training_sessions/{session_id}"
    log_file = os.path.join(session_dir, "training.log")
    if not os.path.exists(log_file):
        raise HTTPException(status_code=404, detail="Session ID not found.")
    with open(log_file, "r", encoding="utf-8") as f:
        logs = f.read()
    return {"session_id": session_id, "logs": logs}