Vishwas1 commited on
Commit
e1e315b
·
verified ·
1 Parent(s): 8da495a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # training_space/app.py (Training Space Backend)
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
+ import subprocess
5
+ import os
6
+ import uuid
7
+ from transformers import HfApi, HfFolder
8
+
9
+ app = FastAPI()
10
+
11
+ # Define the expected payload structure
12
+ class TrainingRequest(BaseModel):
13
+ task: str # 'generation' or 'classification'
14
+ model_params: dict
15
+ model_name: str
16
+ dataset_content: str # The actual content of the dataset
17
+
18
+ # Ensure Hugging Face API token is set as an environment variable
19
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
20
+ if not HF_API_TOKEN:
21
+ raise ValueError("HF_API_TOKEN environment variable not set.")
22
+
23
+ # Save the token
24
+ HfFolder.save_token(HF_API_TOKEN)
25
+ api = HfApi()
26
+
27
+ @app.post("/train")
28
+ def train_model(request: TrainingRequest):
29
+ try:
30
+ # Create a unique directory for this training session
31
+ session_id = str(uuid.uuid4())
32
+ session_dir = f"./training_sessions/{session_id}"
33
+ os.makedirs(session_dir, exist_ok=True)
34
+
35
+ # Save the dataset content to a file
36
+ dataset_path = os.path.join(session_dir, "dataset.txt")
37
+ with open(dataset_path, "w", encoding="utf-8") as f:
38
+ f.write(request.dataset_content)
39
+
40
+ # Prepare the command to run the training script
41
+ cmd = [
42
+ "python", "train_model.py",
43
+ "--task", request.task,
44
+ "--model_name", request.model_name,
45
+ "--dataset", dataset_path,
46
+ "--num_layers", str(request.model_params['num_layers']),
47
+ "--attention_heads", str(request.model_params['attention_heads']),
48
+ "--hidden_size", str(request.model_params['hidden_size']),
49
+ "--vocab_size", str(request.model_params['vocab_size']),
50
+ "--sequence_length", str(request.model_params['sequence_length'])
51
+ ]
52
+
53
+ # Start the training process as a background task
54
+ subprocess.Popen(cmd, cwd=session_dir)
55
+
56
+ return {"status": "Training started", "session_id": session_id}
57
+
58
+ except Exception as e:
59
+ raise HTTPException(status_code=500, detail=str(e))