# LLMTrainingPro / app.py
# (Hugging Face Space file — viewer chrome such as author, commit hash,
#  and "raw / history / blame" links removed; they are not part of the code.)
# training_space/app.py (Training Space Backend)
import logging
import os
import subprocess
import sys
import uuid

from fastapi import FastAPI, HTTPException
from huggingface_hub import HfApi, HfFolder
from pydantic import BaseModel
# FastAPI application instance serving the training backend.
app = FastAPI()

# Append timestamped records to a persistent log file so entries
# survive process restarts.
logging.basicConfig(
    level=logging.INFO,
    filename='training.log',
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
)
# Define the expected payload structure
class TrainingRequest(BaseModel):
    """Request body for POST /train.

    Note: the field annotations double as pydantic validation rules, so they
    must not be tightened without considering existing callers.
    """
    task: str  # 'generation' or 'classification'
    model_params: dict  # expected keys: num_layers, attention_heads, hidden_size, vocab_size, sequence_length (see /train)
    model_name: str  # name used for logging and passed to the training script
    dataset_content: str  # The actual content of the dataset (written verbatim to dataset.txt)
# Ensure Hugging Face API token is set as an environment variable
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
if not HF_API_TOKEN:
    # Fail fast at import time: nothing downstream can authenticate without it.
    raise ValueError("HF_API_TOKEN environment variable not set.")
# Persist the token so subsequent huggingface_hub calls pick it up implicitly.
# NOTE(review): HfFolder.save_token is deprecated in newer huggingface_hub
# releases in favor of huggingface_hub.login(token=...) — confirm the pinned
# library version before changing.
HfFolder.save_token(HF_API_TOKEN)
api = HfApi()
@app.get("/")
def read_root():
    """Landing endpoint: describe the API and how to start a training run."""
    instructions = (
        "To train a model, send a POST request to /train "
        "with the required parameters."
    )
    return {
        "message": "Welcome to the Training Space API!",
        "instructions": instructions,
    }
@app.post("/train")
def train_model(request: TrainingRequest):
    """Start an asynchronous training run in a fresh session directory.

    Writes the dataset to disk, then launches ``train_model.py`` as a detached
    subprocess. Returns immediately with the session id; progress is tracked
    via ``GET /status/{session_id}``.

    Raises:
        HTTPException 400: required ``model_params`` keys are missing.
        HTTPException 500: any unexpected failure while setting up the run.
    """
    # Validate up front so a malformed payload yields a clear 400 instead of
    # an opaque 500 from a KeyError below.
    required_params = (
        "num_layers", "attention_heads", "hidden_size",
        "vocab_size", "sequence_length",
    )
    missing = [p for p in required_params if p not in request.model_params]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing model_params keys: {', '.join(missing)}",
        )
    try:
        logging.info(f"Received training request for model: {request.model_name}, Task: {request.task}")
        # Create a unique directory for this training session.
        session_id = str(uuid.uuid4())
        session_dir = os.path.abspath(os.path.join("training_sessions", session_id))
        os.makedirs(session_dir, exist_ok=True)
        # Save the dataset content to a file inside the session directory.
        dataset_path = os.path.join(session_dir, "dataset.txt")
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.write(request.dataset_content)
        # The child process runs with cwd=session_dir, so the script path must
        # be absolute: a bare "train_model.py" would be resolved inside the
        # session directory, where it does not exist. sys.executable pins the
        # same interpreter that is serving this app.
        script_path = os.path.abspath("train_model.py")
        cmd = [
            sys.executable, script_path,
            "--task", request.task,
            "--model_name", request.model_name,
            "--dataset", dataset_path,
            "--num_layers", str(request.model_params['num_layers']),
            "--attention_heads", str(request.model_params['attention_heads']),
            "--hidden_size", str(request.model_params['hidden_size']),
            "--vocab_size", str(request.model_params['vocab_size']),
            "--sequence_length", str(request.model_params['sequence_length'])
        ]
        # Fire-and-forget: the training process outlives this request.
        subprocess.Popen(cmd, cwd=session_dir)
        logging.info(f"Training started for model: {request.model_name}, Session ID: {session_id}")
        return {"status": "Training started", "session_id": session_id}
    except Exception as e:
        logging.error(f"Error during training request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# Optional: Status Endpoint
@app.get("/status/{session_id}")
def get_status(session_id: str):
    """Return the raw training log for a session.

    Raises:
        HTTPException 400: ``session_id`` is not a well-formed UUID.
        HTTPException 404: no log file exists for the session.
    """
    # session_id is attacker-controlled path input; interpolating it unchecked
    # would allow directory traversal (e.g. "../../somewhere"). Sessions are
    # created with uuid.uuid4(), so anything that does not parse as a UUID is
    # invalid by construction.
    try:
        uuid.UUID(session_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid session ID format.")
    session_dir = f"./training_sessions/{session_id}"
    log_file = os.path.join(session_dir, "training.log")
    if not os.path.exists(log_file):
        raise HTTPException(status_code=404, detail="Session ID not found.")
    with open(log_file, "r", encoding="utf-8") as f:
        logs = f.read()
    return {"session_id": session_id, "logs": logs}