File size: 3,611 Bytes
981a076
e1e315b
 
 
 
 
36071c5
7da4761
981a076
e1e315b
 
 
981a076
 
 
 
 
 
 
 
 
7da4761
981a076
 
3042d4c
7da4761
 
 
 
 
 
 
 
 
 
e1e315b
 
 
 
 
 
 
981a076
deddd5d
 
 
 
 
 
 
981a076
e1e315b
 
 
deddd5d
981a076
e1e315b
 
 
 
 
 
 
 
 
 
21a5890
 
 
e1e315b
21a5890
e1e315b
 
1a39537
981a076
 
 
 
 
e1e315b
21a5890
e1e315b
 
 
deddd5d
 
e1e315b
 
 
deddd5d
e1e315b
deddd5d
 
 
 
 
 
 
 
 
 
 
 
 
981a076
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# training_space/app.py (FastAPI Backend)
import logging
import os
import subprocess
import sys
import uuid

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfApi, HfFolder
from pydantic import BaseModel

app = FastAPI()

# Configure logging: every record from the root logger is appended to
# ``training.log`` in the process's working directory.
logging.basicConfig(
    filename='training.log',
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)

# CORS configuration.
# Allowed origins are compared against the request's ``Origin`` header as exact
# strings. Browsers serialize the host part of an origin in lowercase, so the
# mixed-case Space URL alone would never match a real browser request; the
# lowercase form is listed next to it. The original spelling is kept for any
# non-browser client that sends it verbatim.
origins = [
    "https://Vishwas1-LLMBuilderPro.hf.space",  # Replace with your Gradio frontend Space URL
    "https://vishwas1-llmbuilderpro.hf.space",  # Same Space, as browsers actually send it
    "http://localhost",  # For local testing (default port only; add "http://localhost:<port>" if needed)
    "https://web.postman.co",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define the expected payload structure
class TrainingRequest(BaseModel):
    """JSON body accepted by ``POST /train``.

    The ``/train`` handler forwards ``task`` and ``model_name`` to
    ``train_model.py`` as ``--task`` / ``--model_name``, reads hyperparameters
    out of ``model_params`` and writes ``dataset_content`` to ``dataset.txt``.
    """
    # NOTE(review): under pydantic v2, field names starting with ``model_`` fall
    # into a protected namespace and emit a warning -- confirm the pydantic
    # version in use before renaming or silencing.

    task: str  # 'generation' or 'classification' (typed as plain str; not restricted here)
    model_params: dict  # Hyperparameter overrides; /train reads num_layers, attention_heads, hidden_size, vocab_size, sequence_length
    model_name: str  # Name passed through to the training script
    dataset_content: str  # The actual content of the dataset

# Root Endpoint
@app.get("/")
def read_root():
    """Return a short greeting and a pointer to the ``/train`` endpoint."""
    greeting = "Welcome to the Training Space API!"
    how_to = "To train a model, send a POST request to /train with the required parameters."
    return {"message": greeting, "instructions": how_to}

# Train Endpoint
@app.post("/train")
def train_model(request: TrainingRequest):
    """Save the submitted dataset and launch ``train_model.py`` in the background.

    A fresh UUID names a directory under ``./training_sessions``. The dataset
    text is written there as ``dataset.txt`` and the training script is spawned
    with that directory as its working directory. The handler returns as soon
    as the process has been started; it does not wait for training to finish.

    Args:
        request: Validated request body (task, model name, hyperparameter
            overrides and dataset text).

    Returns:
        ``{"status": "Training started", "session_id": <uuid string>}``.

    Raises:
        HTTPException: status 500, with the error text as ``detail``, if the
            training script is missing or preparing/spawning the session fails.
    """
    # (flag name, default) pairs; a value in request.model_params overrides the default.
    hyperparameter_defaults = (
        ("num_layers", 12),
        ("attention_heads", 1),
        ("hidden_size", 64),
        ("vocab_size", 30000),
        ("sequence_length", 512),
    )

    try:
        logging.info(f"Received training request for model: {request.model_name}, Task: {request.task}")

        # train_model.py is expected next to this file. The path must be
        # absolute because the child runs with cwd=session_dir, where a
        # relative script path would not resolve.
        train_script = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "train_model.py"
        )
        # Fail before creating any session state: otherwise the interpreter
        # would start, exit immediately, and the caller would still be told
        # that training started.
        if not os.path.isfile(train_script):
            raise FileNotFoundError(f"Training script not found: {train_script}")

        # Create a unique directory for this training session
        session_id = str(uuid.uuid4())
        session_dir = f"./training_sessions/{session_id}"
        os.makedirs(session_dir, exist_ok=True)

        # Save the dataset content to a file
        dataset_path = os.path.join(session_dir, "dataset.txt")
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.write(request.dataset_content)

        # Run the child under the same interpreter (and therefore the same
        # installed packages) as this server; fall back to "python" only if
        # the interpreter path is unavailable.
        cmd = [
            sys.executable or "python", train_script,
            "--task", request.task,
            "--model_name", request.model_name,
            "--dataset", os.path.abspath(dataset_path),
        ]
        for flag, default in hyperparameter_defaults:
            cmd += [f"--{flag}", str(request.model_params.get(flag, default))]

        # Start the training process as a background task (argument list, no shell).
        subprocess.Popen(cmd, cwd=session_dir)

        logging.info(f"Training started for model: {request.model_name}, Session ID: {session_id}")

        return {"status": "Training started", "session_id": session_id}

    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception(f"Error during training request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

# Optional: Status Endpoint
@app.get("/status/{session_id}")
def get_status(session_id: str):
    """Return the contents of a training session's ``training.log``.

    The log is read from ``./training_sessions/<session_id>/training.log``.
    Presumably the training script writes it there, since it is launched with
    the session directory as its working directory -- confirm against
    ``train_model.py``.

    Args:
        session_id: Session identifier returned by ``POST /train``. It must
            parse as a UUID.

    Returns:
        ``{"session_id": <as supplied>, "logs": <full log text>}``.

    Raises:
        HTTPException: status 404 if ``session_id`` is not a UUID or no log
            file exists for it.
    """
    # session_id becomes part of a filesystem path, so accept only real UUIDs.
    # Without this, a value such as ".." would resolve outside the session
    # tree (e.g. to the API's own training.log).
    try:
        canonical_id = str(uuid.UUID(session_id))
    except ValueError:
        raise HTTPException(status_code=404, detail="Session ID not found.") from None

    log_file = os.path.join("./training_sessions", canonical_id, "training.log")

    # Open directly instead of exists()+open() so a concurrently removed file
    # still yields a 404 rather than an unhandled error.
    try:
        # errors="replace": a stray non-UTF-8 byte in the log must not turn a
        # status query into a server error.
        with open(log_file, "r", encoding="utf-8", errors="replace") as f:
            logs = f.read()
    except (FileNotFoundError, NotADirectoryError):
        raise HTTPException(status_code=404, detail="Session ID not found.") from None

    return {"session_id": session_id, "logs": logs}