Vishwas1 commited on
Commit
981a076
·
verified ·
1 Parent(s): 8e26074

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -27
app.py CHANGED
@@ -1,19 +1,27 @@
1
- # training_space/app.py (Training Space Backend)
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  import subprocess
5
  import os
6
  import uuid
7
  from huggingface_hub import HfApi, HfFolder
8
- import logging
9
  from fastapi.middleware.cors import CORSMiddleware
 
10
 
11
  app = FastAPI()
12
 
13
- # Configure CORS
 
 
 
 
 
 
 
 
14
  origins = [
15
- "https://huggingface.co/spaces/Vishwas1/LLMBuilderPro", # Replace with your Gradio frontend Space URL
16
- "http://localhost", # If testing locally
17
  "https://web.postman.co",
18
  ]
19
 
@@ -25,14 +33,6 @@ app.add_middleware(
25
  allow_headers=["*"],
26
  )
27
 
28
- # Configure logging
29
- logging.basicConfig(
30
- filename='training.log',
31
- filemode='a',
32
- format='%(asctime)s - %(levelname)s - %(message)s',
33
- level=logging.INFO
34
- )
35
-
36
  # Define the expected payload structure
37
  class TrainingRequest(BaseModel):
38
  task: str # 'generation' or 'classification'
@@ -40,15 +40,7 @@ class TrainingRequest(BaseModel):
40
  model_name: str
41
  dataset_content: str # The actual content of the dataset
42
 
43
- # Ensure Hugging Face API token is set as an environment variable
44
- HF_API_TOKEN = os.getenv("HF_API_TOKEN")
45
- if not HF_API_TOKEN:
46
- raise ValueError("HF_API_TOKEN environment variable not set.")
47
-
48
- # Save the token
49
- HfFolder.save_token(HF_API_TOKEN)
50
- api = HfApi()
51
-
52
  @app.get("/")
53
  def read_root():
54
  return {
@@ -56,10 +48,12 @@ def read_root():
56
  "instructions": "To train a model, send a POST request to /train with the required parameters."
57
  }
58
 
 
59
  @app.post("/train")
60
  def train_model(request: TrainingRequest):
61
  try:
62
  logging.info(f"Received training request for model: {request.model_name}, Task: {request.task}")
 
63
  # Create a unique directory for this training session
64
  session_id = str(uuid.uuid4())
65
  session_dir = f"./training_sessions/{session_id}"
@@ -76,11 +70,11 @@ def train_model(request: TrainingRequest):
76
  "--task", request.task,
77
  "--model_name", request.model_name,
78
  "--dataset", dataset_path,
79
- "--num_layers", str(request.model_params['num_layers']),
80
- "--attention_heads", str(request.model_params['attention_heads']),
81
- "--hidden_size", str(request.model_params['hidden_size']),
82
- "--vocab_size", str(request.model_params['vocab_size']),
83
- "--sequence_length", str(request.model_params['sequence_length'])
84
  ]
85
 
86
  # Start the training process as a background task
@@ -106,3 +100,4 @@ def get_status(session_id: str):
106
  logs = f.read()
107
 
108
  return {"session_id": session_id, "logs": logs}
 
 
1
+ # training_space/app.py (FastAPI Backend)
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  import subprocess
5
  import os
6
  import uuid
7
  from huggingface_hub import HfApi, HfFolder
 
8
  from fastapi.middleware.cors import CORSMiddleware
9
+ import logging
10
 
11
  app = FastAPI()
12
 
13
+ # Configure Logging
14
+ logging.basicConfig(
15
+ filename='training.log',
16
+ filemode='a',
17
+ format='%(asctime)s - %(levelname)s - %(message)s',
18
+ level=logging.INFO
19
+ )
20
+
21
+ # CORS Configuration
22
  origins = [
23
+ "https://Vishwas1-LLMBuilderPro.hf.space", # Replace with your Gradio frontend Space URL
24
+ "http://localhost", # For local testing
25
  "https://web.postman.co",
26
  ]
27
 
 
33
  allow_headers=["*"],
34
  )
35
 
 
 
 
 
 
 
 
 
36
  # Define the expected payload structure
37
  class TrainingRequest(BaseModel):
38
  task: str # 'generation' or 'classification'
 
40
  model_name: str
41
  dataset_content: str # The actual content of the dataset
42
 
43
+ # Root Endpoint
 
 
 
 
 
 
 
 
44
  @app.get("/")
45
  def read_root():
46
  return {
 
48
  "instructions": "To train a model, send a POST request to /train with the required parameters."
49
  }
50
 
51
+ # Train Endpoint
52
  @app.post("/train")
53
  def train_model(request: TrainingRequest):
54
  try:
55
  logging.info(f"Received training request for model: {request.model_name}, Task: {request.task}")
56
+
57
  # Create a unique directory for this training session
58
  session_id = str(uuid.uuid4())
59
  session_dir = f"./training_sessions/{session_id}"
 
70
  "--task", request.task,
71
  "--model_name", request.model_name,
72
  "--dataset", dataset_path,
73
+ "--num_layers", str(request.model_params.get('num_layers', 12)),
74
+ "--attention_heads", str(request.model_params.get('attention_heads', 1)),
75
+ "--hidden_size", str(request.model_params.get('hidden_size', 64)),
76
+ "--vocab_size", str(request.model_params.get('vocab_size', 30000)),
77
+ "--sequence_length", str(request.model_params.get('sequence_length', 512))
78
  ]
79
 
80
  # Start the training process as a background task
 
100
  logs = f.read()
101
 
102
  return {"session_id": session_id, "logs": logs}
103
+