Spaces:
Sleeping
Sleeping
File size: 2,128 Bytes
e1e315b 36071c5 e1e315b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
# training_space/app.py (Training Space Backend)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import subprocess
import os
import uuid
from huggingface_hub import HfApi, HfFolder
# FastAPI application instance; the endpoint decorators in this module
# register their routes on it.
app = FastAPI()
# Payload schema for POST /train.
#
#   task            -- 'generation' or 'classification'
#   model_params    -- hyperparameter mapping; the endpoint reads the keys
#                      num_layers, attention_heads, hidden_size, vocab_size
#                      and sequence_length from it
#   model_name      -- forwarded to the training script as --model_name
#   dataset_content -- the dataset text itself (not a path to a file)
class TrainingRequest(BaseModel):
    task: str
    model_params: dict
    model_name: str
    dataset_content: str
# The Hugging Face token has to be supplied through the environment; abort
# at import time when it is absent or empty.
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
if HF_API_TOKEN is None or HF_API_TOKEN == "":
    raise ValueError("HF_API_TOKEN environment variable not set.")

# Hand the token to HfFolder.save_token, then create the module-level
# HfApi client.
HfFolder.save_token(HF_API_TOKEN)
api = HfApi()
@app.post("/train")
def train_model(request: TrainingRequest):
    """Start a background training run for the submitted dataset.

    Writes the dataset text into a fresh per-session directory and launches
    ``train_model.py`` with that directory as its working directory. The
    call returns as soon as the process has been spawned; it does not wait
    for training to finish.

    Args:
        request: Task, model name, hyperparameters and dataset text.

    Returns:
        A dict with a ``status`` message and the generated ``session_id``.

    Raises:
        HTTPException: 422 when ``model_params`` lacks a required key;
            500 when the session cannot be prepared or the training
            process cannot be started.
    """
    # Order matters: it is the order the flags are appended to the command.
    required_params = (
        "num_layers",
        "attention_heads",
        "hidden_size",
        "vocab_size",
        "sequence_length",
    )

    # Validate outside the try block so a client mistake is reported as a
    # client error instead of being re-wrapped as a 500 below.
    missing = [key for key in required_params if key not in request.model_params]
    if missing:
        raise HTTPException(
            status_code=422,
            detail=f"Missing model_params: {', '.join(missing)}",
        )

    try:
        # Create a unique directory for this training session. The path is
        # made absolute because the child process runs with cwd=session_dir:
        # relative paths would be resolved against the session directory
        # itself and point at files that do not exist.
        session_id = str(uuid.uuid4())
        session_dir = os.path.abspath(
            os.path.join("training_sessions", session_id)
        )
        os.makedirs(session_dir, exist_ok=True)

        # Save the dataset content to a file inside the session directory.
        dataset_path = os.path.join(session_dir, "dataset.txt")
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.write(request.dataset_content)

        # NOTE(review): assumes train_model.py sits next to this module --
        # confirm against the deployment layout.
        script_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "train_model.py"
        )
        if not os.path.isfile(script_path):
            # Popen would still succeed (the interpreter starts, then exits),
            # so check explicitly rather than report a false "started".
            raise FileNotFoundError(f"Training script not found: {script_path}")

        cmd = [
            "python", script_path,
            "--task", request.task,
            "--model_name", request.model_name,
            "--dataset", dataset_path,
        ]
        for key in required_params:
            cmd += [f"--{key}", str(request.model_params[key])]

        # Start the training process in the background; outputs written
        # with relative paths by the script land in the session directory.
        subprocess.Popen(cmd, cwd=session_dir)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"status": "Training started", "session_id": session_id}
|