File size: 3,808 Bytes
981a076
e1e315b
 
 
 
 
36071c5
7da4761
981a076
e1e315b
 
 
981a076
 
 
 
 
 
 
 
 
7da4761
981a076
 
3042d4c
7da4761
 
 
 
 
 
 
 
 
 
e1e315b
 
 
 
 
2de0e9b
 
e1e315b
981a076
deddd5d
 
 
 
 
 
 
981a076
e1e315b
 
 
deddd5d
981a076
e1e315b
 
 
 
 
2de0e9b
 
e1e315b
2de0e9b
21a5890
2de0e9b
 
e1e315b
21a5890
e1e315b
 
2de0e9b
981a076
 
 
 
 
e1e315b
2de0e9b
 
 
e1e315b
deddd5d
 
e1e315b
 
 
deddd5d
e1e315b
deddd5d
2de0e9b
 
 
 
 
deddd5d
 
 
 
 
 
 
 
 
 
 
 
981a076
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# training_space/app.py (FastAPI Backend)
import logging
import os
import subprocess
import sys
import uuid

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfApi, HfFolder
from pydantic import BaseModel

app = FastAPI()

# Configure Logging
# Server-side events are appended to training.log in the process's current
# working directory.
logging.basicConfig(
    filename='training.log',
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)

# CORS Configuration
# Starlette's CORSMiddleware compares the request's Origin header against this
# list by exact string match. Browsers always send the host part of Origin in
# lower case, so entries must be written in lower case or they never match.
origins = [
    "https://vishwas1-llmbuilderpro.hf.space",  # Replace with your Gradio frontend Space URL
    "http://localhost",  # For local testing (matches port 80 only; add e.g. "http://localhost:7860" if needed)
    "https://web.postman.co",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define the expected payload structure
class TrainingRequest(BaseModel):
    """JSON body accepted by ``POST /train``.

    Every field is forwarded to ``train_model.py`` as a command-line argument
    by the /train handler.
    """

    task: str  # 'generation' or 'classification' (not enforced here: any string is accepted)
    # Free-form architecture settings; the /train handler reads the keys
    # num_layers, attention_heads, hidden_size, vocab_size and sequence_length,
    # falling back to its own defaults for any that are missing.
    model_params: dict
    model_name: str
    dataset_name: str  # The name of the existing Hugging Face dataset


# Root Endpoint
@app.get("/")
def read_root():
    """Return a short welcome payload explaining how to start training."""
    welcome = "Welcome to the Training Space API!"
    usage = "To train a model, send a POST request to /train with the required parameters."
    return dict(message=welcome, instructions=usage)

# Train Endpoint
@app.post("/train")
def train_model(request: TrainingRequest):
    """Start ``train_model.py`` in a background process for one training request.

    A new session directory ``./training_sessions/<session_id>`` (relative to
    the server's current working directory) is created. The child's stdout and
    stderr are written to ``training.log`` inside it, which is the file that
    ``GET /status/{session_id}`` reads.

    Args:
        request: Validated payload. ``model_params`` keys that are absent fall
            back to these defaults: num_layers=12, attention_heads=1,
            hidden_size=64, vocab_size=30000, sequence_length=512.

    Returns:
        ``{"status": "Training started", "session_id": <uuid4 string>}``.
        The response is returned as soon as the process is spawned; training
        is not awaited.

    Raises:
        HTTPException: status 500 carrying the error text if any step fails,
            e.g. the session directory cannot be created or the process
            cannot be spawned.
    """
    try:
        logging.info(f"Received training request for model: {request.model_name}, Task: {request.task}")

        # One directory per request, keyed by a random UUID.
        session_id = str(uuid.uuid4())
        session_dir = f"./training_sessions/{session_id}"
        os.makedirs(session_dir, exist_ok=True)

        # Resolve via abspath first: when __file__ is relative,
        # os.path.dirname() returns "" and Popen(cwd="") fails.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        train_model_path = os.path.join(script_dir, "train_model.py")

        params = request.model_params
        cmd = [
            # Run the child with the same interpreter (and virtualenv) as this
            # server; a bare "python" on PATH may be a different installation.
            sys.executable, train_model_path,
            "--task", request.task,
            "--model_name", request.model_name,
            "--dataset_name", request.dataset_name,
            "--num_layers", str(params.get('num_layers', 12)),
            "--attention_heads", str(params.get('attention_heads', 1)),
            "--hidden_size", str(params.get('hidden_size', 64)),
            "--vocab_size", str(params.get('vocab_size', 30000)),
            "--sequence_length", str(params.get('sequence_length', 512)),
        ]

        # Unbuffered output lets /status see progress while training runs;
        # otherwise Python block-buffers stdout when it is redirected to a file.
        child_env = {**os.environ, "PYTHONUNBUFFERED": "1"}

        log_path = os.path.join(session_dir, "training.log")
        # The child inherits its own copy of the file descriptor, so the
        # parent's handle can be closed as soon as Popen returns. The process
        # is deliberately not waited on.
        with open(log_path, "wb") as log_file:
            subprocess.Popen(
                cmd,
                cwd=script_dir,
                stdout=log_file,
                stderr=subprocess.STDOUT,
                env=child_env,
            )

        logging.info(f"Training started for model: {request.model_name}, Session ID: {session_id}")

        return {"status": "Training started", "session_id": session_id}

    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception(f"Error during training request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

# Optional: Status Endpoint
@app.get("/status/{session_id}")
def get_status(session_id: str):
    """Return the contents of a training session's log file.

    Args:
        session_id: Identifier taken from the URL path. It is expected to be
            the UUID string returned by ``POST /train``.

    Returns:
        ``{"session_id": <session_id>, "logs": <full text of the log>}``.

    Raises:
        HTTPException: 404 if ``session_id`` is not a valid UUID or no
            ``training.log`` exists for it.
    """
    # session_id is untrusted URL input that gets joined into a filesystem
    # path. Without this check, ".." would resolve to ./training.log, the
    # server's own log. A parsed UUID re-serialises to hex digits and hyphens
    # only, so the path cannot escape ./training_sessions.
    try:
        canonical_id = str(uuid.UUID(session_id))
    except ValueError:
        raise HTTPException(status_code=404, detail="Session ID not found.") from None

    session_dir = f"./training_sessions/{canonical_id}"
    log_file = os.path.join(session_dir, "training.log")

    # EAFP: open directly rather than exists()+open(), which can race.
    # errors="replace" stops a log that is still being written (possibly cut
    # mid-way through a multi-byte character) from raising UnicodeDecodeError.
    try:
        with open(log_file, "r", encoding="utf-8", errors="replace") as f:
            logs = f.read()
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Session ID not found.") from None

    return {"session_id": session_id, "logs": logs}