Shilpaj committed on
Commit
61f0070
·
1 Parent(s): c0a458a

Feat: Complete single model training and inference

Browse files
app.py CHANGED
@@ -8,11 +8,11 @@ import uvicorn
8
  import torch
9
  from scripts.model import Net
10
  from scripts.training.train import train
11
- import json
12
- import os
13
  from pathlib import Path
14
- import asyncio
15
  from fastapi import BackgroundTasks
 
 
 
16
 
17
  app = FastAPI()
18
 
@@ -83,10 +83,8 @@ async def train_model(config: TrainingConfig, background_tasks: BackgroundTasks)
83
  async def websocket_endpoint(websocket: WebSocket):
84
  await websocket.accept()
85
  try:
86
- # Wait for configuration from client
87
  config_data = await websocket.receive_json()
88
 
89
- # Create model instance with the configuration
90
  model = Net(
91
  kernels=[
92
  config_data['block1'],
@@ -95,7 +93,6 @@ async def websocket_endpoint(websocket: WebSocket):
95
  ]
96
  )
97
 
98
- # Create config object
99
  from scripts.training.config import NetworkConfig
100
  config = NetworkConfig()
101
  config.update(
@@ -104,14 +101,14 @@ async def websocket_endpoint(websocket: WebSocket):
104
  block3=config_data['block3'],
105
  optimizer=config_data['optimizer'],
106
  batch_size=config_data['batch_size'],
107
- epochs=1
108
  )
109
 
110
  print(f"Starting training with config: {config_data}")
111
 
112
- # Start training with websocket for real-time updates
113
  try:
114
- await train(model, config, websocket)
 
115
  await websocket.send_json({
116
  "type": "training_complete",
117
  "data": {
@@ -134,6 +131,66 @@ async def websocket_endpoint(websocket: WebSocket):
134
  finally:
135
  print("WebSocket connection closed")
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  # @app.post("/api/train_single")
138
  # async def train_single_model(config: TrainingConfig):
139
  # try:
@@ -165,6 +222,36 @@ async def train_compare_models(config: ComparisonConfig):
165
  except Exception as e:
166
  raise HTTPException(status_code=500, detail=str(e))
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  @app.post("/api/inference")
169
  async def perform_inference(data: dict):
170
  try:
@@ -175,10 +262,23 @@ async def perform_inference(data: dict):
175
  model_path = Path("scripts/training/models") / f"{model_name}.pth"
176
  if not model_path.exists():
177
  raise HTTPException(status_code=404, detail=f"Model not found: {model_path}")
 
 
 
 
 
178
 
179
- # Load model and perform inference
180
- model = Net()
181
- model.load_state_dict(torch.load(str(model_path), map_location=torch.device('cpu')))
 
 
 
 
 
 
 
 
182
  model.eval()
183
 
184
  # Process image data and get prediction
@@ -216,7 +316,15 @@ async def perform_inference(data: dict):
216
  output = model(image_tensor)
217
  prediction = output.argmax(dim=1).item()
218
 
219
- return {"prediction": prediction}
 
 
 
 
 
 
 
 
220
 
221
  except Exception as e:
222
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
 
8
  import torch
9
  from scripts.model import Net
10
  from scripts.training.train import train
 
 
11
  from pathlib import Path
 
12
  from fastapi import BackgroundTasks
13
+ import warnings
14
+
15
+ warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms")
16
 
17
  app = FastAPI()
18
 
 
83
  async def websocket_endpoint(websocket: WebSocket):
84
  await websocket.accept()
85
  try:
 
86
  config_data = await websocket.receive_json()
87
 
 
88
  model = Net(
89
  kernels=[
90
  config_data['block1'],
 
93
  ]
94
  )
95
 
 
96
  from scripts.training.config import NetworkConfig
97
  config = NetworkConfig()
98
  config.update(
 
101
  block3=config_data['block3'],
102
  optimizer=config_data['optimizer'],
103
  batch_size=config_data['batch_size'],
104
+ epochs=config_data['epochs']
105
  )
106
 
107
  print(f"Starting training with config: {config_data}")
108
 
 
109
  try:
110
+ # Pass "single" as model_type for single model training
111
+ await train(model, config, websocket, model_type="single")
112
  await websocket.send_json({
113
  "type": "training_complete",
114
  "data": {
 
131
  finally:
132
  print("WebSocket connection closed")
133
 
134
@app.websocket("/ws/compare")
async def websocket_compare_endpoint(websocket: WebSocket):
    """Train two models back to back and stream progress over one WebSocket.

    After accepting the connection, a single JSON message is read. When its
    ``type`` is ``"start_comparison"``, the ``model1`` and ``model2`` entries
    are applied to two fresh ``NetworkConfig`` objects, a ``Net`` is built for
    each from its ``block1``/``block2``/``block3`` values, and the two models
    are trained one after the other (tagged ``model_1`` and ``model_2``).
    A ``comparison_complete`` message is sent on success and a
    ``training_error`` message if training raises. Messages of any other
    type are ignored.
    """
    await websocket.accept()
    try:
        data = await websocket.receive_json()
        if data.get("type") != "start_comparison":
            # Nothing to do for other message types; ``finally`` still logs.
            return

        from scripts.training.config import NetworkConfig

        # Create both configs first, then apply the client-supplied values,
        # so a malformed payload fails before any model is constructed.
        configs = [NetworkConfig(), NetworkConfig()]
        for cfg, payload_key in zip(configs, ("model1", "model2")):
            cfg.update(**data[payload_key])

        # One network per configuration, sized by its three block widths.
        models = [
            Net(kernels=[cfg.block1, cfg.block2, cfg.block3])
            for cfg in configs
        ]

        try:
            # Sequential training; the tag is forwarded as ``model_type``.
            for net, cfg, tag in zip(models, configs, ("model_1", "model_2")):
                await train(net, cfg, websocket, model_type=tag)

            await websocket.send_json({
                "type": "comparison_complete",
                "data": {
                    "message": "Training completed successfully!"
                }
            })
        except Exception as e:
            print(f"Training error: {str(e)}")
            await websocket.send_json({
                "type": "training_error",
                "data": {
                    "message": f"Training failed: {str(e)}"
                }
            })

    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except Exception as e:
        print(f"WebSocket error: {str(e)}")
    finally:
        print("WebSocket connection closed")
193
+
194
  # @app.post("/api/train_single")
195
  # async def train_single_model(config: TrainingConfig):
196
  # try:
 
222
  except Exception as e:
223
  raise HTTPException(status_code=500, detail=str(e))
224
 
225
def parse_model_filename(filename):
    """Extract the training configuration encoded in a model filename.

    The name is split on ``_`` and the values are located relative to the
    marker tokens ``arch``, ``opt`` and ``batch``, so any prefix before
    ``arch`` and any suffix after the batch size are ignored.

    Example: ``single_arch_32_64_128_opt_adam_batch_64_20240322_123456.pth``

    Args:
        filename: Model file name, with or without a trailing ``.pth``.

    Returns:
        A dict with integer ``block1``, ``block2``, ``block3`` and
        ``batch_size`` plus the string ``optimizer``, or ``None`` when the
        name does not follow the expected layout.
    """
    extension = '.pth'
    try:
        # Strip the extension first: otherwise a name whose last token is a
        # value (e.g. "..._batch_64.pth") fails on int("64.pth").
        stem = filename[:-len(extension)] if filename.endswith(extension) else filename
        parts = stem.split('_')

        # Three block widths follow the 'arch' marker.
        arch_index = parts.index('arch')
        block1 = int(parts[arch_index + 1])
        block2 = int(parts[arch_index + 2])
        block3 = int(parts[arch_index + 3])

        # Optimizer name follows the 'opt' marker.
        optimizer = parts[parts.index('opt') + 1]

        # Batch size follows the 'batch' marker.
        batch_size = int(parts[parts.index('batch') + 1])
    except (AttributeError, IndexError, TypeError, ValueError) as e:
        # Best-effort contract: malformed or non-string input yields None.
        print(f"Error parsing model filename: {e}")
        return None

    return {
        'block1': block1,
        'block2': block2,
        'block3': block3,
        'optimizer': optimizer,
        'batch_size': batch_size
    }
254
+
255
  @app.post("/api/inference")
256
  async def perform_inference(data: dict):
257
  try:
 
262
  model_path = Path("scripts/training/models") / f"{model_name}.pth"
263
  if not model_path.exists():
264
  raise HTTPException(status_code=404, detail=f"Model not found: {model_path}")
265
+
266
+ # Parse model configuration from filename
267
+ config = parse_model_filename(model_name)
268
+ if not config:
269
+ raise HTTPException(status_code=500, detail="Could not parse model configuration")
270
 
271
+ # Create model with the correct configuration
272
+ model = Net(
273
+ kernels=[
274
+ config['block1'],
275
+ config['block2'],
276
+ config['block3']
277
+ ]
278
+ )
279
+
280
+ # Load model weights
281
+ model.load_state_dict(torch.load(str(model_path), map_location=torch.device('cpu'), weights_only=True))
282
  model.eval()
283
 
284
  # Process image data and get prediction
 
316
  output = model(image_tensor)
317
  prediction = output.argmax(dim=1).item()
318
 
319
+ # Add configuration info to response
320
+ return {
321
+ "prediction": prediction,
322
+ "model_config": {
323
+ "architecture": f"{config['block1']}-{config['block2']}-{config['block3']}",
324
+ "optimizer": config['optimizer'],
325
+ "batch_size": config['batch_size']
326
+ }
327
+ }
328
 
329
  except Exception as e:
330
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
scripts/training/train.py CHANGED
@@ -13,6 +13,17 @@ import shutil
13
  from tqdm import tqdm
14
  import asyncio
15
 
 
 
 
 
 
 
 
 
 
 
 
16
  def download_and_extract_mnist_data():
17
  """Download and extract MNIST dataset from a reliable mirror"""
18
  base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
@@ -107,7 +118,7 @@ def validate(model, test_loader, criterion, device):
107
 
108
  return val_loss, val_acc
109
 
110
- async def train(model, config, websocket=None):
111
  print("\nStarting training...")
112
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
113
  print(f"Using device: {device}")
@@ -140,27 +151,25 @@ async def train(model, config, websocket=None):
140
 
141
  print(f"Dataset loaded. Training samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")
142
 
143
- # Initialize optimizer based on config
144
- if config.optimizer.lower() == 'adam':
145
- optimizer = optim.Adam(model.parameters())
146
- else:
147
- optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
148
-
149
- criterion = nn.CrossEntropyLoss()
150
-
151
  print("\nTraining Configuration:")
 
152
  print(f"Optimizer: {config.optimizer}")
153
  print(f"Batch Size: {config.batch_size}")
154
  print(f"Network Architecture: {config.block1}-{config.block2}-{config.block3}")
155
  print("\nStarting training loop...")
156
 
157
  best_val_acc = 0
158
- history = {
159
- 'train_loss': [],
160
- 'train_acc': [],
161
- 'val_loss': [],
162
- 'val_acc': []
163
- }
 
 
 
 
 
164
 
165
  try:
166
  for epoch in range(config.epochs):
@@ -204,10 +213,11 @@ async def train(model, config, websocket=None):
204
  # Send training update through websocket
205
  if websocket:
206
  try:
 
207
  await websocket.send_json({
208
  'type': 'training_update',
209
  'data': {
210
- 'step': batch_idx + epoch * len(train_loader),
211
  'train_loss': current_loss,
212
  'train_acc': current_acc
213
  }
@@ -215,10 +225,6 @@ async def train(model, config, websocket=None):
215
  except Exception as e:
216
  print(f"Error sending websocket update: {e}")
217
 
218
- # Calculate epoch metrics
219
- train_loss = total_loss / len(train_loader)
220
- train_acc = 100. * correct / total
221
-
222
  # Validation phase
223
  model.eval()
224
  val_loss = 0
@@ -240,7 +246,7 @@ async def train(model, config, websocket=None):
240
 
241
  # Print epoch results
242
  print(f"\nEpoch {epoch+1}/{config.epochs} Results:")
243
- print(f"Training Loss: {train_loss:.4f} | Training Accuracy: {train_acc:.2f}%")
244
  print(f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")
245
 
246
  # Send validation update through websocket
@@ -257,12 +263,17 @@ async def train(model, config, websocket=None):
257
  except Exception as e:
258
  print(f"Error sending websocket update: {e}")
259
 
260
- # Save best model
261
  if val_acc > best_val_acc:
262
  best_val_acc = val_acc
263
  print(f"\nNew best validation accuracy: {val_acc:.2f}%")
264
- print("Saving model...")
265
- torch.save(model.state_dict(), 'best_model.pth')
 
 
 
 
 
266
 
267
  except Exception as e:
268
  print(f"\nError during training: {e}")
@@ -270,4 +281,4 @@ async def train(model, config, websocket=None):
270
 
271
  print("\nTraining completed!")
272
  print(f"Best validation accuracy: {best_val_acc:.2f}%")
273
- return history
 
13
  from tqdm import tqdm
14
  import asyncio
15
 
16
def generate_model_filename(config, model_type="single", timestamp=None):
    """Build a checkpoint filename that encodes the model configuration.

    Layout:
    ``{model_type}_arch_{b1}_{b2}_{b3}_opt_{optimizer}_batch_{n}_{YYYYmmdd_HHMMSS}.pth``

    Args:
        config: Object exposing ``block1``, ``block2``, ``block3``,
            ``optimizer`` (string) and ``batch_size``.
        model_type: Name prefix; can be "single", "model_1", or "model_2".
        timestamp: Optional ``datetime`` to stamp the name with. Defaults to
            the current local time, matching the previous behavior. Passing
            one lets a caller reuse the same name across several saves.

    Returns:
        The filename as a string (no directory component).
    """
    if timestamp is None:
        timestamp = datetime.now()
    stamp = timestamp.strftime("%Y%m%d_%H%M%S")

    arch = f"{config.block1}_{config.block2}_{config.block3}"
    # Lower-cased so the name is stable regardless of how the client
    # capitalized the optimizer.
    opt = config.optimizer.lower()
    batch = str(config.batch_size)

    return f"{model_type}_arch_{arch}_opt_{opt}_batch_{batch}_{stamp}.pth"
26
+
27
  def download_and_extract_mnist_data():
28
  """Download and extract MNIST dataset from a reliable mirror"""
29
  base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
 
118
 
119
  return val_loss, val_acc
120
 
121
+ async def train(model, config, websocket=None, model_type="single"):
122
  print("\nStarting training...")
123
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
124
  print(f"Using device: {device}")
 
151
 
152
  print(f"Dataset loaded. Training samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")
153
 
 
 
 
 
 
 
 
 
154
  print("\nTraining Configuration:")
155
+ print(f"Epochs: {config.epochs}")
156
  print(f"Optimizer: {config.optimizer}")
157
  print(f"Batch Size: {config.batch_size}")
158
  print(f"Network Architecture: {config.block1}-{config.block2}-{config.block3}")
159
  print("\nStarting training loop...")
160
 
161
  best_val_acc = 0
162
+ criterion = nn.CrossEntropyLoss()
163
+
164
+ # Initialize optimizer based on config
165
+ if config.optimizer.lower() == 'adam':
166
+ optimizer = optim.Adam(model.parameters())
167
+ else:
168
+ optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
169
+
170
+ # Create models directory if it doesn't exist
171
+ models_dir = Path("scripts/training/models")
172
+ models_dir.mkdir(parents=True, exist_ok=True)
173
 
174
  try:
175
  for epoch in range(config.epochs):
 
213
  # Send training update through websocket
214
  if websocket:
215
  try:
216
+ step = batch_idx + epoch * len(train_loader)
217
  await websocket.send_json({
218
  'type': 'training_update',
219
  'data': {
220
+ 'step': step,
221
  'train_loss': current_loss,
222
  'train_acc': current_acc
223
  }
 
225
  except Exception as e:
226
  print(f"Error sending websocket update: {e}")
227
 
 
 
 
 
228
  # Validation phase
229
  model.eval()
230
  val_loss = 0
 
246
 
247
  # Print epoch results
248
  print(f"\nEpoch {epoch+1}/{config.epochs} Results:")
249
+ print(f"Training Loss: {current_loss:.4f} | Training Accuracy: {current_acc:.2f}%")
250
  print(f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")
251
 
252
  # Send validation update through websocket
 
263
  except Exception as e:
264
  print(f"Error sending websocket update: {e}")
265
 
266
+ # Save best model with configuration in filename
267
  if val_acc > best_val_acc:
268
  best_val_acc = val_acc
269
  print(f"\nNew best validation accuracy: {val_acc:.2f}%")
270
+
271
+ # Generate filename with configuration
272
+ model_filename = generate_model_filename(config, model_type)
273
+ model_path = models_dir / model_filename
274
+
275
+ print(f"Saving model as: {model_filename}")
276
+ torch.save(model.state_dict(), model_path)
277
 
278
  except Exception as e:
279
  print(f"\nError during training: {e}")
 
281
 
282
  print("\nTraining completed!")
283
  print(f"Best validation accuracy: {best_val_acc:.2f}%")
284
+ return None
static/js/inference.js CHANGED
@@ -46,8 +46,18 @@ function setupCanvas() {
46
  }
47
 
48
  function clearCanvas() {
 
 
 
 
49
  ctx.fillStyle = "white";
50
  ctx.fillRect(0, 0, canvas.width, canvas.height);
 
 
 
 
 
 
51
  }
52
 
53
  async function predict() {
 
46
  }
47
 
48
  function clearCanvas() {
49
+ const canvas = document.getElementById('drawing-canvas');
50
+ const ctx = canvas.getContext('2d');
51
+
52
+ // Clear the canvas
53
  ctx.fillStyle = "white";
54
  ctx.fillRect(0, 0, canvas.width, canvas.height);
55
+ ctx.beginPath();
56
+
57
+ // Hide and clear prediction result
58
+ const resultDiv = document.getElementById('prediction-result');
59
+ resultDiv.classList.add('hidden');
60
+ resultDiv.innerHTML = '';
61
  }
62
 
63
  async function predict() {
templates/train_compare.html CHANGED
@@ -67,6 +67,14 @@
67
  <option value="128">128</option>
68
  </select>
69
  </div>
 
 
 
 
 
 
 
 
70
  </div>
71
  </div>
72
 
@@ -126,6 +134,14 @@
126
  <option value="128">128</option>
127
  </select>
128
  </div>
 
 
 
 
 
 
 
 
129
  </div>
130
  </div>
131
  </div>
@@ -343,7 +359,8 @@
343
  block2: parseInt(document.getElementById('model1_block2').value),
344
  block3: parseInt(document.getElementById('model1_block3').value),
345
  optimizer: document.getElementById('model1_optimizer').value,
346
- batch_size: parseInt(document.getElementById('model1_batch_size').value)
 
347
  };
348
 
349
  const model2Config = {
@@ -351,7 +368,8 @@
351
  block2: parseInt(document.getElementById('model2_block2').value),
352
  block3: parseInt(document.getElementById('model2_block3').value),
353
  optimizer: document.getElementById('model2_optimizer').value,
354
- batch_size: parseInt(document.getElementById('model2_batch_size').value)
 
355
  };
356
 
357
  // Setup WebSocket connection
 
67
  <option value="128">128</option>
68
  </select>
69
  </div>
70
+ <div class="config-item">
71
+ <label for="model1_epochs">Epochs:</label>
72
+ <select id="model1_epochs" name="epochs">
73
+ <option value="1">1</option>
74
+ <option value="2">2</option>
75
+ <option value="3">3</option>
76
+ </select>
77
+ </div>
78
  </div>
79
  </div>
80
 
 
134
  <option value="128">128</option>
135
  </select>
136
  </div>
137
+ <div class="config-item">
138
+ <label for="model2_epochs">Epochs:</label>
139
+ <select id="model2_epochs" name="epochs">
140
+ <option value="1">1</option>
141
+ <option value="2">2</option>
142
+ <option value="3">3</option>
143
+ </select>
144
+ </div>
145
  </div>
146
  </div>
147
  </div>
 
359
  block2: parseInt(document.getElementById('model1_block2').value),
360
  block3: parseInt(document.getElementById('model1_block3').value),
361
  optimizer: document.getElementById('model1_optimizer').value,
362
+ batch_size: parseInt(document.getElementById('model1_batch_size').value),
363
+ epochs: parseInt(document.getElementById('model1_epochs').value)
364
  };
365
 
366
  const model2Config = {
 
368
  block2: parseInt(document.getElementById('model2_block2').value),
369
  block3: parseInt(document.getElementById('model2_block3').value),
370
  optimizer: document.getElementById('model2_optimizer').value,
371
+ batch_size: parseInt(document.getElementById('model2_batch_size').value),
372
+ epochs: parseInt(document.getElementById('model2_epochs').value)
373
  };
374
 
375
  // Setup WebSocket connection
templates/train_single.html CHANGED
@@ -67,6 +67,14 @@
67
  <option value="128">128</option>
68
  </select>
69
  </div>
 
 
 
 
 
 
 
 
70
  </div>
71
  </div>
72
 
@@ -81,6 +89,13 @@
81
  <div id="lossChart"></div>
82
  <div id="accuracyChart"></div>
83
  </div>
 
 
 
 
 
 
 
84
  </div>
85
 
86
  <script>
@@ -215,7 +230,7 @@
215
  block3: parseInt(document.getElementById('block3').value),
216
  optimizer: document.getElementById('optimizer').value,
217
  batch_size: parseInt(document.getElementById('batch_size').value),
218
- epochs: 1 // Add default epochs value
219
  };
220
  ws.send(JSON.stringify(config));
221
  };
@@ -261,6 +276,8 @@
261
  else if (data.type === 'training_complete') {
262
  alert(data.data.message);
263
  stopTraining();
 
 
264
  }
265
  else if (data.type === 'training_error') {
266
  alert(data.data.message);
@@ -369,6 +386,22 @@
369
  height: 400px;
370
  width: 100%;
371
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  </style>
373
  </body>
374
  </html>
 
67
  <option value="128">128</option>
68
  </select>
69
  </div>
70
+ <div class="config-item">
71
+ <label for="epochs">Epochs:</label>
72
+ <select id="epochs" name="epochs">
73
+ <option value="1">1</option>
74
+ <option value="2">2</option>
75
+ <option value="3">3</option>
76
+ </select>
77
+ </div>
78
  </div>
79
  </div>
80
 
 
89
  <div id="lossChart"></div>
90
  <div id="accuracyChart"></div>
91
  </div>
92
+
93
+ <!-- Inference Controls -->
94
+ <div class="inference-controls" style="display: none;">
95
+ <button id="goToInference" onclick="window.location.href='/inference'" class="inference-button">
96
+ Try Model Inference
97
+ </button>
98
+ </div>
99
  </div>
100
 
101
  <script>
 
230
  block3: parseInt(document.getElementById('block3').value),
231
  optimizer: document.getElementById('optimizer').value,
232
  batch_size: parseInt(document.getElementById('batch_size').value),
233
+ epochs: parseInt(document.getElementById('epochs').value)
234
  };
235
  ws.send(JSON.stringify(config));
236
  };
 
276
  else if (data.type === 'training_complete') {
277
  alert(data.data.message);
278
  stopTraining();
279
+ // Show the inference button
280
+ document.querySelector('.inference-controls').style.display = 'block';
281
  }
282
  else if (data.type === 'training_error') {
283
  alert(data.data.message);
 
386
  height: 400px;
387
  width: 100%;
388
  }
389
+
390
+ .inference-controls {
391
+ margin: 20px 0;
392
+ text-align: center;
393
+ }
394
+
395
+ .inference-button {
396
+ background-color: #28a745;
397
+ padding: 12px 24px;
398
+ font-size: 1.1em;
399
+ transition: background-color 0.3s;
400
+ }
401
+
402
+ .inference-button:hover {
403
+ background-color: #218838;
404
+ }
405
  </style>
406
  </body>
407
  </html>