Spaces:
Sleeping
Sleeping
<!doctype html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Train Single Model - MNIST</title>
  <!-- Project stylesheet; the URL is resolved by the server-side template engine -->
  <link rel="stylesheet" href="{{ url_for('static', path='/css/style.css') }}">
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Roboto+Mono&display=swap" rel="stylesheet">
  <!-- Plotly.js is required by the inline script in <body> (Plotly.newPlot / purge / extendTraces).
       Pinned to a 2.x release because the chart layouts pass `title` as a plain string.
       `defer` keeps parsing unblocked; deferred scripts still run before DOMContentLoaded,
       which is when the inline script first touches Plotly. -->
  <script src="https://cdn.plot.ly/plotly-2.35.2.min.js" charset="utf-8" defer></script>
</head>
| <body> | |
| <div class="container"> | |
| <h1>Train Single Model</h1> | |
<!-- Model configuration panel.
     The ids of the <select> elements below are read by startTraining() in the page
     script when it assembles the configuration message, so they must stay as they are. -->
<div class="model-config">
  <h3>Model Configuration</h3>

  <!-- Network architecture: one value per block (defaults 32 / 64 / 128) -->
  <div class="network-config">
    <h4>Network Architecture</h4>
    <div class="block-config">
      <div class="block">
        <label for="block1">Block-1:</label>
        <select class="form-select" id="block1" name="block1">
          <option value="8">8</option>
          <option value="16">16</option>
          <option value="32" selected>32</option>
          <option value="64">64</option>
          <option value="128">128</option>
        </select>
      </div>
      <div class="block">
        <label for="block2">Block-2:</label>
        <select class="form-select" id="block2" name="block2">
          <option value="8">8</option>
          <option value="16">16</option>
          <option value="32">32</option>
          <option value="64" selected>64</option>
          <option value="128">128</option>
        </select>
      </div>
      <div class="block">
        <label for="block3">Block-3:</label>
        <select class="form-select" id="block3" name="block3">
          <option value="8">8</option>
          <option value="16">16</option>
          <option value="32">32</option>
          <option value="64">64</option>
          <option value="128" selected>128</option>
        </select>
      </div>
    </div>
  </div>

  <!-- Training hyperparameters: optimizer, batch size, epoch count.
       "epochs" has no explicit selected option, so the first one (1) is the default. -->
  <div class="training-config">
    <div class="config-item">
      <label for="optimizer">Optimizer:</label>
      <select id="optimizer" name="optimizer">
        <option value="SGD" selected>SGD</option>
        <option value="Adam">Adam</option>
      </select>
    </div>
    <div class="config-item">
      <label for="batch_size">Batch Size:</label>
      <select id="batch_size" name="batch_size">
        <option value="32">32</option>
        <option value="64" selected>64</option>
        <option value="128">128</option>
      </select>
    </div>
    <div class="config-item">
      <label for="epochs">Epochs:</label>
      <select id="epochs" name="epochs">
        <option value="1">1</option>
        <option value="2">2</option>
        <option value="3">3</option>
      </select>
    </div>
  </div>
</div>
<!-- Training controls: wired to startTraining() / stopTraining() in the page script.
     Stop starts out disabled; the script swaps the two disabled states.
     type="button" is explicit so these never act as submit buttons if a form is
     ever wrapped around them. -->
<div class="controls">
  <button type="button" id="startTraining" onclick="startTraining()">Start Training</button>
  <button type="button" id="stopTraining" onclick="stopTraining()" disabled>Stop Training</button>
</div>

<!-- Training progress: empty mount points that the page script fills with Plotly charts -->
<div class="charts-container">
  <div id="lossChart"></div>
  <div id="accuracyChart"></div>
</div>

<!-- Inference controls: kept hidden via the inline style until the script receives a
     training_complete message and sets display to 'block' -->
<div class="inference-controls" style="display: none;">
  <button type="button" class="inference-button" id="goToInference" onclick="window.location.href='/inference'">
    Try Model Inference
  </button>
</div>
| </div> | |
| <script> | |
// Page-level state.
// `ws` is the WebSocket of the current training run; startTraining() assigns it.
let ws;
// Declared as chart handles; nothing in this script assigns them.
let lossChart;
let accuracyChart;

// Once the DOM is parsed, draw the two empty charts so the page shows
// labelled axes before any training data arrives.
document.addEventListener('DOMContentLoaded', () => {
  // Factory for an empty line series with the given legend label.
  const emptySeries = (label) => ({
    x: [],
    y: [],
    name: label,
    type: 'scatter'
  });

  // Both charts share the same x axis definition.
  const iterationsAxis = () => ({
    title: 'Iterations',
    rangemode: 'nonnegative'
  });

  // Loss chart: training and validation series, y axis kept non-negative.
  Plotly.newPlot(
    'lossChart',
    [emptySeries('Training Loss'), emptySeries('Validation Loss')],
    {
      title: 'Loss',
      xaxis: iterationsAxis(),
      yaxis: {
        title: 'Loss',
        rangemode: 'nonnegative'
      }
    }
  );

  // Accuracy chart: training and validation series, y axis fixed to 0-100.
  Plotly.newPlot(
    'accuracyChart',
    [emptySeries('Training Accuracy'), emptySeries('Validation Accuracy')],
    {
      title: 'Accuracy',
      xaxis: iterationsAxis(),
      yaxis: {
        title: 'Accuracy (%)',
        range: [0, 100]
      }
    }
  );
});
/**
 * Start a training run.
 *
 * Swaps the Start/Stop button states, resets both charts, opens a WebSocket to
 * `/ws/train` on the current host, sends the selected configuration as JSON once
 * the socket is open, and streams incoming metric messages into the charts.
 *
 * Message types handled: `training_update`, `validation_update`,
 * `training_complete`, `training_error`. Other types are ignored.
 */
function startTraining() {
  // Disable start button and enable stop button
  document.getElementById('startTraining').disabled = true;
  document.getElementById('stopTraining').disabled = false;

  // Clear previous charts and redraw them empty
  const emptyTrace = (name) => ({ x: [], y: [], name: name, type: 'scatter' });
  Plotly.purge('lossChart');
  Plotly.purge('accuracyChart');
  Plotly.newPlot('lossChart', [emptyTrace('Training Loss'), emptyTrace('Validation Loss')], {
    title: 'Loss',
    xaxis: { title: 'Iterations', rangemode: 'nonnegative' },
    yaxis: { title: 'Loss', rangemode: 'nonnegative' }
  });
  Plotly.newPlot('accuracyChart', [emptyTrace('Training Accuracy'), emptyTrace('Validation Accuracy')], {
    title: 'Accuracy',
    xaxis: { title: 'Iterations', rangemode: 'nonnegative' },
    yaxis: { title: 'Accuracy (%)', range: [0, 100] }
  });

  // Setup WebSocket connection.
  // Match the page's scheme: a plain ws:// URL is blocked as mixed content
  // when the page itself is served over https.
  const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
  const socket = new WebSocket(`${scheme}://${window.location.host}/ws/train`);
  ws = socket;

  // Events from an earlier run's socket can arrive after a new run has started
  // (close is delivered asynchronously). Such events must not touch the new run,
  // otherwise a late close would call stopTraining() and shut the new socket.
  const isCurrent = () => ws === socket;

  // Read a <select> value as a base-10 integer.
  const intValue = (id) => parseInt(document.getElementById(id).value, 10);

  socket.onopen = function() {
    console.log("WebSocket connection established");
    // Send configuration through WebSocket
    const config = {
      block1: intValue('block1'),
      block2: intValue('block2'),
      block3: intValue('block3'),
      optimizer: document.getElementById('optimizer').value,
      batch_size: intValue('batch_size'),
      epochs: intValue('epochs')
    };
    socket.send(JSON.stringify(config));
  };

  socket.onerror = function(error) {
    console.error("WebSocket error:", error);
    if (!isCurrent()) {
      return;
    }
    stopTraining();
    alert("Error connecting to training server");
  };

  socket.onclose = function() {
    console.log("WebSocket connection closed");
    if (isCurrent()) {
      stopTraining();
    }
  };

  socket.onmessage = function(event) {
    if (!isCurrent()) {
      return;
    }

    let data;
    try {
      data = JSON.parse(event.data);
    } catch (parseError) {
      // A malformed frame should not abort the handler with an uncaught exception.
      console.error("Ignoring malformed WebSocket message:", parseError);
      return;
    }

    if (data.type === 'training_update') {
      // Update training metrics (trace index 0)
      Plotly.extendTraces('lossChart', {
        x: [[data.data.step]],
        y: [[data.data.train_loss]]
      }, [0]);
      Plotly.extendTraces('accuracyChart', {
        x: [[data.data.step]],
        y: [[data.data.train_acc]]
      }, [0]);
    }
    else if (data.type === 'validation_update') {
      // Update validation metrics (trace index 1)
      Plotly.extendTraces('lossChart', {
        x: [[data.data.step]],
        y: [[data.data.val_loss]]
      }, [1]);
      Plotly.extendTraces('accuracyChart', {
        x: [[data.data.step]],
        y: [[data.data.val_acc]]
      }, [1]);
    }
    else if (data.type === 'training_complete') {
      alert(data.data.message);
      stopTraining();
      // Show the inference button
      document.querySelector('.inference-controls').style.display = 'block';
    }
    else if (data.type === 'training_error') {
      alert(data.data.message);
      stopTraining();
    }
  };
}
/**
 * Stop the current training run: close the WebSocket if one has been created,
 * then re-enable Start and disable Stop.
 */
function stopTraining() {
  const startButton = document.getElementById('startTraining');
  const stopButton = document.getElementById('stopTraining');

  // `ws` stays undefined until the first startTraining() call.
  if (ws) {
    ws.close();
  }

  startButton.disabled = false;
  stopButton.disabled = true;
}
| </script> | |
<style>
  /* ---- Page wrapper ---- */
  .container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
  }

  /* ---- Configuration panel ---- */
  .model-config {
    padding: 20px;
    border: 1px solid #ddd;
    border-radius: 5px;
    margin-bottom: 20px;
  }

  .network-config {
    margin-bottom: 20px;
  }

  .network-config h4 {
    margin: 0 0 15px 0;
    font-size: 1.1em;
  }

  /* Both option rows are flex rows with equally sized cells */
  .block-config {
    display: flex;
    justify-content: space-between;
    gap: 20px;
  }

  .training-config {
    display: flex;
    gap: 20px;
  }

  .block,
  .config-item {
    flex: 1;
  }

  .block label,
  .config-item label {
    display: block;
    margin-bottom: 5px;
    font-weight: bold;
  }

  select {
    width: 100%;
    padding: 8px;
    border: 1px solid #ddd;
    border-radius: 4px;
  }

  /* ---- Buttons ---- */
  .controls {
    margin: 20px 0;
  }

  button {
    padding: 10px 20px;
    margin-right: 10px;
    border: none;
    border-radius: 4px;
    background-color: #007bff;
    color: white;
    cursor: pointer;
  }

  button:disabled {
    background-color: #ccc;
    cursor: not-allowed;
  }

  /* ---- Charts: stacked vertically, fixed height ---- */
  .charts-container {
    display: flex;
    flex-direction: column;
    gap: 20px;
    margin-top: 20px;
  }

  #lossChart,
  #accuracyChart {
    height: 400px;
    width: 100%;
  }

  /* ---- Inference call-to-action ---- */
  .inference-controls {
    margin: 20px 0;
    text-align: center;
  }

  .inference-button {
    background-color: #28a745;
    padding: 12px 24px;
    font-size: 1.1em;
    transition: background-color 0.3s;
  }

  .inference-button:hover {
    background-color: #218838;
  }
</style>
| </body> | |
| </html> |