Vishwas1 committed on
Commit
2de0e9b
·
verified ·
1 Parent(s): a61270e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -38,7 +38,8 @@ class TrainingRequest(BaseModel):
38
  task: str # 'generation' or 'classification'
39
  model_params: dict
40
  model_name: str
41
- dataset_content: str # The actual content of the dataset
 
42
 
43
  # Root Endpoint
44
  @app.get("/")
@@ -59,28 +60,27 @@ def train_model(request: TrainingRequest):
59
  session_dir = f"./training_sessions/{session_id}"
60
  os.makedirs(session_dir, exist_ok=True)
61
 
62
- # Save the dataset content to a file
63
- dataset_path = os.path.join(session_dir, "dataset.txt")
64
- with open(dataset_path, "w", encoding="utf-8") as f:
65
- f.write(request.dataset_content)
66
 
67
- # Define the path to train_model.py (assuming it's in the root directory)
68
  TRAIN_MODEL_PATH = os.path.join(os.path.dirname(__file__), "train_model.py")
69
-
 
70
  cmd = [
71
  "python", TRAIN_MODEL_PATH,
72
  "--task", request.task,
73
  "--model_name", request.model_name,
74
- "--dataset", os.path.abspath(dataset_path),
75
  "--num_layers", str(request.model_params.get('num_layers', 12)),
76
  "--attention_heads", str(request.model_params.get('attention_heads', 1)),
77
  "--hidden_size", str(request.model_params.get('hidden_size', 64)),
78
  "--vocab_size", str(request.model_params.get('vocab_size', 30000)),
79
  "--sequence_length", str(request.model_params.get('sequence_length', 512))
80
  ]
81
-
82
- # Start the training process as a background task
83
- subprocess.Popen(cmd, cwd=session_dir)
84
 
85
  logging.info(f"Training started for model: {request.model_name}, Session ID: {session_id}")
86
 
@@ -90,6 +90,11 @@ def train_model(request: TrainingRequest):
90
  logging.error(f"Error during training request: {str(e)}")
91
  raise HTTPException(status_code=500, detail=str(e))
92
 
 
 
 
 
 
93
  # Optional: Status Endpoint
94
  @app.get("/status/{session_id}")
95
  def get_status(session_id: str):
 
38
  task: str # 'generation' or 'classification'
39
  model_params: dict
40
  model_name: str
41
+ dataset_name: str # The name of the existing Hugging Face dataset
42
+
43
 
44
  # Root Endpoint
45
  @app.get("/")
 
60
  session_dir = f"./training_sessions/{session_id}"
61
  os.makedirs(session_dir, exist_ok=True)
62
 
63
+ # No need to save dataset content; use dataset_name directly
64
+ dataset_name = request.dataset_name
 
 
65
 
66
+ # Define the absolute path to train_model.py
67
  TRAIN_MODEL_PATH = os.path.join(os.path.dirname(__file__), "train_model.py")
68
+
69
+ # Prepare the command to run the training script with dataset_name
70
  cmd = [
71
  "python", TRAIN_MODEL_PATH,
72
  "--task", request.task,
73
  "--model_name", request.model_name,
74
+ "--dataset_name", dataset_name, # Pass dataset_name instead of dataset file path
75
  "--num_layers", str(request.model_params.get('num_layers', 12)),
76
  "--attention_heads", str(request.model_params.get('attention_heads', 1)),
77
  "--hidden_size", str(request.model_params.get('hidden_size', 64)),
78
  "--vocab_size", str(request.model_params.get('vocab_size', 30000)),
79
  "--sequence_length", str(request.model_params.get('sequence_length', 512))
80
  ]
81
+
82
+ # Start the training process as a background task in the root directory
83
+ subprocess.Popen(cmd, cwd=os.path.dirname(__file__))
84
 
85
  logging.info(f"Training started for model: {request.model_name}, Session ID: {session_id}")
86
 
 
90
  logging.error(f"Error during training request: {str(e)}")
91
  raise HTTPException(status_code=500, detail=str(e))
92
 
93
+
94
+ except Exception as e:
95
+ logging.error(f"Error during training request: {str(e)}")
96
+ raise HTTPException(status_code=500, detail=str(e))
97
+
98
  # Optional: Status Endpoint
99
  @app.get("/status/{session_id}")
100
  def get_status(session_id: str):