File size: 3,264 Bytes
e1e315b
 
 
 
 
 
36071c5
deddd5d
e1e315b
 
 
deddd5d
 
 
 
 
 
 
 
e1e315b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deddd5d
 
 
 
 
 
 
e1e315b
 
 
deddd5d
e1e315b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deddd5d
 
e1e315b
 
 
deddd5d
e1e315b
deddd5d
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# training_space/app.py (Training Space Backend)
import logging
import os
import subprocess
import sys
import uuid

from fastapi import FastAPI, HTTPException
from huggingface_hub import HfApi, HfFolder
from pydantic import BaseModel

app = FastAPI()

# Append application-level log records (INFO and above) to training.log,
# a path relative to the server's working directory.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filemode='a',
    filename='training.log',
)

# Define the expected payload structure
class TrainingRequest(BaseModel):
    """Request body accepted by ``POST /train``."""

    task: str  # 'generation' or 'classification'; forwarded to train_model.py as --task
    # Hyper-parameters; /train reads num_layers, attention_heads, hidden_size,
    # vocab_size and sequence_length from it and passes each as a CLI flag.
    model_params: dict
    model_name: str  # forwarded to train_model.py as --model_name
    dataset_content: str  # The actual content of the dataset (written to dataset.txt)

# A Hugging Face token is mandatory: stop at import time if the environment
# variable is missing or empty.
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
if HF_API_TOKEN is None or HF_API_TOKEN == "":
    raise ValueError("HF_API_TOKEN environment variable not set.")

# Store the token via huggingface_hub so that HfApi(), constructed below
# without an explicit token, can pick it up (presumably also relied on by
# spawned training processes — confirm against train_model.py).
HfFolder.save_token(HF_API_TOKEN)
api = HfApi()

@app.get("/")
def read_root():
    """Return a welcome message plus a hint on how to start training."""
    welcome = "Welcome to the Training Space API!"
    usage = "To train a model, send a POST request to /train with the required parameters."
    return {"message": welcome, "instructions": usage}

@app.post("/train")
def train_model(request: TrainingRequest):
    """Start a background training run for the submitted dataset.

    Writes ``request.dataset_content`` into a fresh per-session directory and
    launches ``train_model.py`` as a child process whose working directory is
    that session directory. The task, model name and architecture
    hyper-parameters are passed on the child's command line. The call returns
    as soon as the process has been spawned; it does not wait for training.

    Args:
        request: Parsed request body. ``request.model_params`` must contain
            ``num_layers``, ``attention_heads``, ``hidden_size``,
            ``vocab_size`` and ``sequence_length``.

    Returns:
        ``{"status": "Training started", "session_id": <uuid4 string>}``.

    Raises:
        HTTPException: 400 if a required ``model_params`` key is missing;
            500 if creating the session files or spawning the process fails.
    """
    logging.info(f"Received training request for model: {request.model_name}, Task: {request.task}")

    # Validate up front. Otherwise a missing key raises KeyError inside the
    # try-block below and is misreported as a 500 server error.
    required_params = (
        "num_layers",
        "attention_heads",
        "hidden_size",
        "vocab_size",
        "sequence_length",
    )
    missing = [name for name in required_params if name not in request.model_params]
    if missing:
        detail = f"Missing required model_params: {', '.join(missing)}"
        logging.error(f"Error during training request: {detail}")
        raise HTTPException(status_code=400, detail=detail)

    try:
        # Create a unique directory for this training session. Make it
        # absolute because the child is started with cwd=session_dir. A
        # relative path on the child's command line would be resolved from
        # inside that directory and point at a non-existent file.
        session_id = str(uuid.uuid4())
        session_dir = os.path.abspath(os.path.join("training_sessions", session_id))
        os.makedirs(session_dir, exist_ok=True)

        # Save the dataset content to a file
        dataset_path = os.path.join(session_dir, "dataset.txt")
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.write(request.dataset_content)

        # Resolve the training script against the server's working directory,
        # presumably where "train_model.py" was meant to be found. This must
        # happen before the child's cwd is switched to the session directory.
        script_path = os.path.abspath("train_model.py")

        # Use sys.executable so the script runs under the same interpreter and
        # environment as this server, not whatever "python" is first on PATH.
        params = request.model_params
        cmd = [
            sys.executable, script_path,
            "--task", request.task,
            "--model_name", request.model_name,
            "--dataset", dataset_path,
            "--num_layers", str(params["num_layers"]),
            "--attention_heads", str(params["attention_heads"]),
            "--hidden_size", str(params["hidden_size"]),
            "--vocab_size", str(params["vocab_size"]),
            "--sequence_length", str(params["sequence_length"]),
        ]

        # Start the training process in the background. The Popen handle is
        # not kept, so this endpoint never waits on or monitors the child.
        subprocess.Popen(cmd, cwd=session_dir)

        logging.info(f"Training started for model: {request.model_name}, Session ID: {session_id}")

        return {"status": "Training started", "session_id": session_id}

    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception(f"Error during training request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

# Optional: Status Endpoint
@app.get("/status/{session_id}")
def get_status(session_id: str):
    """Return the contents of a training session's log file.

    Args:
        session_id: Identifier returned by ``POST /train``.

    Returns:
        ``{"session_id": session_id, "logs": <text of the session's training.log>}``.

    Raises:
        HTTPException: 404 if ``session_id`` is not a UUID, or if no
            ``training.log`` exists in that session's directory.
    """
    # /train creates session IDs with uuid.uuid4(), so anything that does not
    # parse as a UUID cannot name a session. Building the path from the
    # canonical UUID string also keeps it inside training_sessions/. Without
    # this check, session_id ".." resolves to ./training.log, the server's own
    # application log.
    try:
        canonical_id = str(uuid.UUID(session_id))
    except ValueError:
        raise HTTPException(status_code=404, detail="Session ID not found.") from None

    session_dir = os.path.join("training_sessions", canonical_id)
    log_file = os.path.join(session_dir, "training.log")
    if not os.path.exists(log_file):
        raise HTTPException(status_code=404, detail="Session ID not found.")

    # errors="replace": a stray non-UTF-8 byte in the log should not turn a
    # status query into a 500.
    with open(log_file, "r", encoding="utf-8", errors="replace") as f:
        logs = f.read()

    return {"session_id": session_id, "logs": logs}