Feat: Complete single model training and inference
- app.py +121 -13
- scripts/training/train.py +36 -25
- static/js/inference.js +10 -0
- templates/train_compare.html +20 -2
- templates/train_single.html +34 -1
app.py
CHANGED

@@ -8,11 +8,11 @@ import uvicorn
 import torch
 from scripts.model import Net
 from scripts.training.train import train
-import json
-import os
 from pathlib import Path
-import asyncio
 from fastapi import BackgroundTasks
+import warnings
+
+warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms")

 app = FastAPI()

@@ -83,10 +83,8 @@ async def train_model(config: TrainingConfig, background_tasks: BackgroundTasks)
 async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
     try:
-        # Wait for configuration from client
         config_data = await websocket.receive_json()

-        # Create model instance with the configuration
         model = Net(
             kernels=[
                 config_data['block1'],
@@ -95,7 +93,6 @@ async def websocket_endpoint(websocket: WebSocket):
             ]
         )

-        # Create config object
         from scripts.training.config import NetworkConfig
         config = NetworkConfig()
         config.update(
@@ -104,14 +101,14 @@ async def websocket_endpoint(websocket: WebSocket):
             block3=config_data['block3'],
             optimizer=config_data['optimizer'],
             batch_size=config_data['batch_size'],
-            epochs=
+            epochs=config_data['epochs']
         )

         print(f"Starting training with config: {config_data}")

-        # Start training with websocket for real-time updates
         try:
-
+            # Pass "single" as model_type for single model training
+            await train(model, config, websocket, model_type="single")
             await websocket.send_json({
                 "type": "training_complete",
                 "data": {
@@ -134,6 +131,66 @@ async def websocket_endpoint(websocket: WebSocket):
     finally:
         print("WebSocket connection closed")

+@app.websocket("/ws/compare")
+async def websocket_compare_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    try:
+        data = await websocket.receive_json()
+        if data.get("type") == "start_comparison":
+            from scripts.training.config import NetworkConfig
+
+            # Create and train both models
+            model1_config = NetworkConfig()
+            model2_config = NetworkConfig()
+
+            # Update configs with received data
+            model1_config.update(**data["model1"])
+            model2_config.update(**data["model2"])
+
+            # Create models with respective configurations
+            model1 = Net(
+                kernels=[
+                    model1_config.block1,
+                    model1_config.block2,
+                    model1_config.block3
+                ]
+            )
+
+            model2 = Net(
+                kernels=[
+                    model2_config.block1,
+                    model2_config.block2,
+                    model2_config.block3
+                ]
+            )
+
+            # Train both models with appropriate model_type
+            try:
+                await train(model1, model1_config, websocket, model_type="model_1")
+                await train(model2, model2_config, websocket, model_type="model_2")
+
+                await websocket.send_json({
+                    "type": "comparison_complete",
+                    "data": {
+                        "message": "Training completed successfully!"
+                    }
+                })
+            except Exception as e:
+                print(f"Training error: {str(e)}")
+                await websocket.send_json({
+                    "type": "training_error",
+                    "data": {
+                        "message": f"Training failed: {str(e)}"
+                    }
+                })
+
+    except WebSocketDisconnect:
+        print("WebSocket disconnected")
+    except Exception as e:
+        print(f"WebSocket error: {str(e)}")
+    finally:
+        print("WebSocket connection closed")
+
 # @app.post("/api/train_single")
 # async def train_single_model(config: TrainingConfig):
 #     try:
@@ -165,6 +222,36 @@ async def train_compare_models(config: ComparisonConfig):
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))

+def parse_model_filename(filename):
+    """Extract configuration from model filename"""
+    # Example filename: single_arch_32_64_128_opt_adam_batch_64_20240322_123456.pth
+    try:
+        parts = filename.split('_')
+        # Find architecture values
+        arch_index = parts.index('arch')
+        block1 = int(parts[arch_index + 1])
+        block2 = int(parts[arch_index + 2])
+        block3 = int(parts[arch_index + 3])
+
+        # Find optimizer
+        opt_index = parts.index('opt')
+        optimizer = parts[opt_index + 1]
+
+        # Find batch size
+        batch_index = parts.index('batch')
+        batch_size = int(parts[batch_index + 1])
+
+        return {
+            'block1': block1,
+            'block2': block2,
+            'block3': block3,
+            'optimizer': optimizer,
+            'batch_size': batch_size
+        }
+    except Exception as e:
+        print(f"Error parsing model filename: {e}")
+        return None
+
 @app.post("/api/inference")
 async def perform_inference(data: dict):
     try:
@@ -175,10 +262,23 @@ async def perform_inference(data: dict):
         model_path = Path("scripts/training/models") / f"{model_name}.pth"
         if not model_path.exists():
             raise HTTPException(status_code=404, detail=f"Model not found: {model_path}")
+
+        # Parse model configuration from filename
+        config = parse_model_filename(model_name)
+        if not config:
+            raise HTTPException(status_code=500, detail="Could not parse model configuration")

-        #
-        model = Net(
-
+        # Create model with the correct configuration
+        model = Net(
+            kernels=[
+                config['block1'],
+                config['block2'],
+                config['block3']
+            ]
+        )
+
+        # Load model weights
+        model.load_state_dict(torch.load(str(model_path), map_location=torch.device('cpu'), weights_only=True))
         model.eval()

         # Process image data and get prediction
@@ -216,7 +316,15 @@ async def perform_inference(data: dict):
         output = model(image_tensor)
         prediction = output.argmax(dim=1).item()

-
+        # Add configuration info to response
+        return {
+            "prediction": prediction,
+            "model_config": {
+                "architecture": f"{config['block1']}-{config['block2']}-{config['block3']}",
+                "optimizer": config['optimizer'],
+                "batch_size": config['batch_size']
+            }
+        }

     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
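For reference, a minimal client sketch for the new /ws/compare endpoint. The host/port and the exact config keys are assumptions inferred from the templates and the NetworkConfig.update calls above; it relies on the third-party `websockets` package being installed and the app running locally.

# Hypothetical smoke-test client for /ws/compare (host, port, and config values are illustrative).
import asyncio
import json
import websockets

async def run_comparison():
    async with websockets.connect("ws://localhost:8000/ws/compare") as ws:
        # Payload shape mirrors what train_compare.html sends: a type marker plus two model configs.
        await ws.send(json.dumps({
            "type": "start_comparison",
            "model1": {"block1": 32, "block2": 64, "block3": 128,
                       "optimizer": "adam", "batch_size": 64, "epochs": 1},
            "model2": {"block1": 16, "block2": 32, "block3": 64,
                       "optimizer": "sgd", "batch_size": 128, "epochs": 1},
        }))
        # Print progress messages until the server reports completion or an error.
        while True:
            msg = json.loads(await ws.recv())
            print(msg["type"], msg.get("data"))
            if msg["type"] in ("comparison_complete", "training_error"):
                break

asyncio.run(run_comparison())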
scripts/training/train.py
CHANGED

@@ -13,6 +13,17 @@ import shutil
 from tqdm import tqdm
 import asyncio

+def generate_model_filename(config, model_type="single"):
+    """Generate a filename based on model configuration
+    model_type can be "single", "model_1", or "model_2"
+    """
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    arch = f"{config.block1}_{config.block2}_{config.block3}"
+    opt = config.optimizer.lower()
+    batch = str(config.batch_size)
+
+    return f"{model_type}_arch_{arch}_opt_{opt}_batch_{batch}_{timestamp}.pth"
+
 def download_and_extract_mnist_data():
     """Download and extract MNIST dataset from a reliable mirror"""
     base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
@@ -107,7 +118,7 @@ def validate(model, test_loader, criterion, device):

     return val_loss, val_acc

-async def train(model, config, websocket=None):
+async def train(model, config, websocket=None, model_type="single"):
     print("\nStarting training...")
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")
@@ -140,27 +151,25 @@ async def train(model, config, websocket=None):

     print(f"Dataset loaded. Training samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

-    # Initialize optimizer based on config
-    if config.optimizer.lower() == 'adam':
-        optimizer = optim.Adam(model.parameters())
-    else:
-        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
-
-    criterion = nn.CrossEntropyLoss()
-
     print("\nTraining Configuration:")
+    print(f"Epochs: {config.epochs}")
     print(f"Optimizer: {config.optimizer}")
     print(f"Batch Size: {config.batch_size}")
     print(f"Network Architecture: {config.block1}-{config.block2}-{config.block3}")
     print("\nStarting training loop...")

     best_val_acc = 0
-
-
-
-
-
-
+    criterion = nn.CrossEntropyLoss()
+
+    # Initialize optimizer based on config
+    if config.optimizer.lower() == 'adam':
+        optimizer = optim.Adam(model.parameters())
+    else:
+        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+
+    # Create models directory if it doesn't exist
+    models_dir = Path("scripts/training/models")
+    models_dir.mkdir(parents=True, exist_ok=True)

     try:
         for epoch in range(config.epochs):
@@ -204,10 +213,11 @@ async def train(model, config, websocket=None):
             # Send training update through websocket
             if websocket:
                 try:
+                    step = batch_idx + epoch * len(train_loader)
                     await websocket.send_json({
                         'type': 'training_update',
                         'data': {
-                            'step':
+                            'step': step,
                             'train_loss': current_loss,
                             'train_acc': current_acc
                         }
@@ -215,10 +225,6 @@ async def train(model, config, websocket=None):
                 except Exception as e:
                     print(f"Error sending websocket update: {e}")

-        # Calculate epoch metrics
-        train_loss = total_loss / len(train_loader)
-        train_acc = 100. * correct / total
-
         # Validation phase
         model.eval()
         val_loss = 0
@@ -240,7 +246,7 @@ async def train(model, config, websocket=None):

         # Print epoch results
         print(f"\nEpoch {epoch+1}/{config.epochs} Results:")
-        print(f"Training Loss: {
+        print(f"Training Loss: {current_loss:.4f} | Training Accuracy: {current_acc:.2f}%")
         print(f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")

         # Send validation update through websocket
@@ -257,12 +263,17 @@ async def train(model, config, websocket=None):
                 except Exception as e:
                     print(f"Error sending websocket update: {e}")

-        # Save best model
+        # Save best model with configuration in filename
         if val_acc > best_val_acc:
             best_val_acc = val_acc
             print(f"\nNew best validation accuracy: {val_acc:.2f}%")
-
-
+
+            # Generate filename with configuration
+            model_filename = generate_model_filename(config, model_type)
+            model_path = models_dir / model_filename
+
+            print(f"Saving model as: {model_filename}")
+            torch.save(model.state_dict(), model_path)

     except Exception as e:
         print(f"\nError during training: {e}")
@@ -270,4 +281,4 @@ async def train(model, config, websocket=None):

     print("\nTraining completed!")
     print(f"Best validation accuracy: {best_val_acc:.2f}%")
-    return
+    return None
static/js/inference.js
CHANGED

@@ -46,8 +46,18 @@ function setupCanvas() {
 }

 function clearCanvas() {
+    const canvas = document.getElementById('drawing-canvas');
+    const ctx = canvas.getContext('2d');
+
+    // Clear the canvas
     ctx.fillStyle = "white";
     ctx.fillRect(0, 0, canvas.width, canvas.height);
+    ctx.beginPath();
+
+    // Hide and clear prediction result
+    const resultDiv = document.getElementById('prediction-result');
+    resultDiv.classList.add('hidden');
+    resultDiv.innerHTML = '';
 }

 async function predict() {
templates/train_compare.html
CHANGED

@@ -67,6 +67,14 @@
                 <option value="128">128</option>
             </select>
         </div>
+        <div class="config-item">
+            <label for="model1_epochs">Epochs:</label>
+            <select id="model1_epochs" name="epochs">
+                <option value="1">1</option>
+                <option value="2">2</option>
+                <option value="3">3</option>
+            </select>
+        </div>
     </div>
 </div>

@@ -126,6 +134,14 @@
                 <option value="128">128</option>
             </select>
         </div>
+        <div class="config-item">
+            <label for="model2_epochs">Epochs:</label>
+            <select id="model2_epochs" name="epochs">
+                <option value="1">1</option>
+                <option value="2">2</option>
+                <option value="3">3</option>
+            </select>
+        </div>
     </div>
     </div>
 </div>
@@ -343,7 +359,8 @@
             block2: parseInt(document.getElementById('model1_block2').value),
             block3: parseInt(document.getElementById('model1_block3').value),
             optimizer: document.getElementById('model1_optimizer').value,
-            batch_size: parseInt(document.getElementById('model1_batch_size').value)
+            batch_size: parseInt(document.getElementById('model1_batch_size').value),
+            epochs: parseInt(document.getElementById('model1_epochs').value)
         };

         const model2Config = {
@@ -351,7 +368,8 @@
             block2: parseInt(document.getElementById('model2_block2').value),
             block3: parseInt(document.getElementById('model2_block3').value),
             optimizer: document.getElementById('model2_optimizer').value,
-            batch_size: parseInt(document.getElementById('model2_batch_size').value)
+            batch_size: parseInt(document.getElementById('model2_batch_size').value),
+            epochs: parseInt(document.getElementById('model2_epochs').value)
         };

         // Setup WebSocket connection
templates/train_single.html
CHANGED

@@ -67,6 +67,14 @@
                 <option value="128">128</option>
             </select>
         </div>
+        <div class="config-item">
+            <label for="epochs">Epochs:</label>
+            <select id="epochs" name="epochs">
+                <option value="1">1</option>
+                <option value="2">2</option>
+                <option value="3">3</option>
+            </select>
+        </div>
     </div>
 </div>

@@ -81,6 +89,13 @@
         <div id="lossChart"></div>
         <div id="accuracyChart"></div>
     </div>
+
+    <!-- Inference Controls -->
+    <div class="inference-controls" style="display: none;">
+        <button id="goToInference" onclick="window.location.href='/inference'" class="inference-button">
+            Try Model Inference
+        </button>
+    </div>
 </div>

 <script>
@@ -215,7 +230,7 @@
             block3: parseInt(document.getElementById('block3').value),
             optimizer: document.getElementById('optimizer').value,
             batch_size: parseInt(document.getElementById('batch_size').value),
-            epochs:
+            epochs: parseInt(document.getElementById('epochs').value)
         };
         ws.send(JSON.stringify(config));
     };
@@ -261,6 +276,8 @@
         else if (data.type === 'training_complete') {
             alert(data.data.message);
             stopTraining();
+            // Show the inference button
+            document.querySelector('.inference-controls').style.display = 'block';
         }
         else if (data.type === 'training_error') {
             alert(data.data.message);
@@ -369,6 +386,22 @@
         height: 400px;
         width: 100%;
     }
+
+    .inference-controls {
+        margin: 20px 0;
+        text-align: center;
+    }
+
+    .inference-button {
+        background-color: #28a745;
+        padding: 12px 24px;
+        font-size: 1.1em;
+        transition: background-color 0.3s;
+    }
+
+    .inference-button:hover {
+        background-color: #218838;
+    }
     </style>
 </body>
 </html>
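With the epochs dropdown added, the single-training page now sends a config message like the one sketched below (shown as a Python dict for reference; the values are illustrative, and the exact WebSocket route for single-model training is defined earlier in app.py outside this diff).

# Approximate shape of the JSON config sent by train_single.html over its WebSocket.
single_training_config = {
    "block1": 32,
    "block2": 64,
    "block3": 128,
    "optimizer": "adam",
    "batch_size": 64,
    "epochs": 2,  # new field, read from the #epochs <select>
}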