Spaces:
Sleeping
Sleeping
File size: 3,611 Bytes
981a076 e1e315b 36071c5 7da4761 981a076 e1e315b 981a076 7da4761 981a076 3042d4c 7da4761 e1e315b 981a076 deddd5d 981a076 e1e315b deddd5d 981a076 e1e315b 21a5890 e1e315b 21a5890 e1e315b 1a39537 981a076 e1e315b 21a5890 e1e315b deddd5d e1e315b deddd5d e1e315b deddd5d 981a076 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
# training_space/app.py (FastAPI Backend)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import subprocess
import os
import uuid
from huggingface_hub import HfApi, HfFolder
from fastapi.middleware.cors import CORSMiddleware
import logging
app = FastAPI()

# Configure Logging: records at INFO level and above are appended to
# training.log (a path relative to the process's working directory), each one
# prefixed with its timestamp and severity.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filename="training.log",
    filemode="a",
)
# CORS Configuration
# Origins allowed to call this API from a browser.
# The CORS middleware compares the request's Origin header against this list
# with an exact, case-sensitive string match, and browsers send the host part
# of an origin in lowercase. Entries therefore have to be written in lowercase
# (a mixed-case Space URL would never match any browser request).
origins = [
    "https://vishwas1-llmbuilderpro.hf.space",  # Replace with your Gradio frontend Space URL
    # For local testing. Exact match: a dev server on a non-default port sends
    # an origin that includes the port and needs its own entry.
    "http://localhost",
    "https://web.postman.co",
]

# Credentials, all HTTP methods and all request headers are accepted, but only
# for the origins listed above.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define the expected payload structure
class TrainingRequest(BaseModel):
    # JSON body accepted by POST /train. All four fields are required.

    # 'generation' or 'classification'
    task: str

    # Architecture settings. The /train handler looks up num_layers,
    # attention_heads, hidden_size, vocab_size and sequence_length here and
    # falls back to its own defaults for any key that is absent.
    model_params: dict

    # Forwarded to the training script as --model_name.
    model_name: str

    # The actual content of the dataset
    dataset_content: str
# Root Endpoint
@app.get("/")
def read_root():
    # Landing route: a static greeting plus a pointer to the /train endpoint.
    greeting = "Welcome to the Training Space API!"
    usage_hint = "To train a model, send a POST request to /train with the required parameters."
    return {"message": greeting, "instructions": usage_hint}
# Train Endpoint
@app.post("/train")
def train_model(request: TrainingRequest):
    """Save the uploaded dataset and launch train_model.py in the background.

    A fresh session directory is created under ./training_sessions/, the
    dataset text is written there as dataset.txt, and train_model.py (expected
    next to this file) is started with that directory as its working
    directory. The call returns as soon as the process has been spawned; it
    does not wait for training to finish or check its exit status.

    Args:
        request: Task, model name, architecture parameters and raw dataset
            text. Keys missing from ``model_params`` fall back to the
            defaults in ``param_defaults`` below.

    Returns:
        A dict with a ``status`` message and the ``session_id`` of the new
        session.

    Raises:
        HTTPException: 500, carrying the error text, if the session directory
            or dataset file cannot be written or the process cannot be started.
    """
    import sys  # local import: only needed to locate the running interpreter

    try:
        logging.info(
            "Received training request for model: %s, Task: %s",
            request.model_name,
            request.task,
        )

        # Create a unique directory for this training session
        session_id = str(uuid.uuid4())
        session_dir = f"./training_sessions/{session_id}"
        os.makedirs(session_dir, exist_ok=True)

        # Save the dataset content to a file
        dataset_path = os.path.join(session_dir, "dataset.txt")
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.write(request.dataset_content)

        # The child runs with cwd=session_dir, so every path handed to it must
        # be absolute; a relative __file__ would otherwise be resolved against
        # the session directory and the script would not be found.
        train_script = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "train_model.py")
        )

        # Hyper-parameters forwarded to the script, with the value used when
        # the request does not supply one.
        param_defaults = {
            "num_layers": 12,
            "attention_heads": 1,
            "hidden_size": 64,
            "vocab_size": 30000,
            "sequence_length": 512,
        }

        # Run the script with the interpreter that is running this server, so
        # it sees the same environment and installed packages; fall back to
        # whatever "python" is on PATH if the interpreter path is unavailable.
        cmd = [
            sys.executable or "python", train_script,
            "--task", request.task,
            "--model_name", request.model_name,
            "--dataset", os.path.abspath(dataset_path),
        ]
        for name, default in param_defaults.items():
            cmd += [f"--{name}", str(request.model_params.get(name, default))]

        # Start the training process as a background task. The argument list
        # is passed without a shell, so request values cannot be interpreted
        # as shell syntax.
        # NOTE(review): the /status endpoint reads training.log from the
        # session directory; presumably train_model.py writes it there, since
        # nothing here redirects the child's output. Confirm.
        subprocess.Popen(cmd, cwd=session_dir)

        logging.info(
            "Training started for model: %s, Session ID: %s",
            request.model_name,
            session_id,
        )
        return {"status": "Training started", "session_id": session_id}
    except Exception as e:
        # Top-level boundary for this route: record the traceback, then report
        # the failure to the client as a 500.
        logging.exception("Error during training request: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Optional: Status Endpoint
@app.get("/status/{session_id}")
def get_status(session_id: str):
    """Return the contents of training.log for one training session.

    Args:
        session_id: Identifier returned by POST /train.

    Returns:
        A dict with the ``session_id`` as supplied and the full text of the
        session's log under ``logs``.

    Raises:
        HTTPException: 404 if ``session_id`` is not a UUID or the session has
            no training.log.
    """
    not_found = HTTPException(status_code=404, detail="Session ID not found.")

    # Session ids are issued by /train as str(uuid.uuid4()). A value that does
    # not parse as a UUID cannot name a session, and rejecting it keeps values
    # such as ".." from resolving to files outside ./training_sessions/.
    try:
        canonical_id = str(uuid.UUID(session_id))
    except ValueError:
        raise not_found from None

    # The path is built from the canonical form (hex digits and hyphens only).
    log_file = os.path.join(f"./training_sessions/{canonical_id}", "training.log")

    # Open directly instead of checking existence first, so a log that
    # disappears between the check and the read cannot surface as a 500.
    # errors="replace" keeps stray non-UTF-8 bytes in the log from failing
    # the whole request.
    try:
        with open(log_file, "r", encoding="utf-8", errors="replace") as f:
            logs = f.read()
    except FileNotFoundError:
        raise not_found from None

    return {"session_id": session_id, "logs": logs}
|