Shilpaj commited on
Commit
4c1a791
·
1 Parent(s): 9f13b34

Feat: Craete frontend and backend for the project

Browse files
static/css/style.css ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :root {
2
+ --primary-color: #6366f1;
3
+ --background-dark: #0f172a;
4
+ --text-light: #e2e8f0;
5
+ --text-gray: #94a3b8;
6
+ --card-bg: #1e293b;
7
+ --gradient-start: #818cf8;
8
+ --gradient-end: #6366f1;
9
+ --gradient-hover-start: #6366f1;
10
+ --gradient-hover-end: #4f46e5;
11
+ }
12
+
13
+ body {
14
+ font-family: 'Inter', sans-serif;
15
+ margin: 0;
16
+ padding: 0;
17
+ background-color: var(--background-dark);
18
+ color: var(--text-light);
19
+ }
20
+
21
+ .container {
22
+ max-width: 1200px;
23
+ margin: 0 auto;
24
+ padding: 2rem;
25
+ }
26
+
27
+ h1 {
28
+ font-size: 3rem;
29
+ font-weight: 700;
30
+ text-align: center;
31
+ margin-bottom: 2rem;
32
+ background: linear-gradient(to right, #818cf8, #6366f1);
33
+ -webkit-background-clip: text;
34
+ -webkit-text-fill-color: transparent;
35
+ }
36
+
37
+ .button-container {
38
+ display: flex;
39
+ gap: 1rem;
40
+ justify-content: center;
41
+ margin: 2rem 0;
42
+ }
43
+
44
+ .btn {
45
+ padding: 0.75rem 1.5rem;
46
+ background: linear-gradient(135deg, var(--gradient-start), var(--gradient-end));
47
+ color: white;
48
+ border: none;
49
+ border-radius: 0.5rem;
50
+ cursor: pointer;
51
+ text-decoration: none;
52
+ font-size: 1rem;
53
+ font-weight: 500;
54
+ transition: all 0.3s ease;
55
+ position: relative;
56
+ z-index: 1;
57
+ overflow: hidden;
58
+ }
59
+
60
+ .btn::before {
61
+ content: '';
62
+ position: absolute;
63
+ top: 0;
64
+ left: 0;
65
+ right: 0;
66
+ bottom: 0;
67
+ background: linear-gradient(135deg, var(--gradient-hover-start), var(--gradient-hover-end));
68
+ opacity: 0;
69
+ transition: opacity 0.3s ease;
70
+ z-index: -1;
71
+ }
72
+
73
+ .btn:hover::before {
74
+ opacity: 1;
75
+ }
76
+
77
+ .btn:hover {
78
+ transform: translateY(-2px);
79
+ box-shadow: 0 4px 15px rgba(99, 102, 241, 0.5);
80
+ }
81
+
82
+ .card {
83
+ background-color: var(--card-bg);
84
+ border-radius: 1rem;
85
+ padding: 2rem;
86
+ margin: 1rem 0;
87
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
88
+ position: relative;
89
+ }
90
+
91
+ .card::before {
92
+ content: '';
93
+ position: absolute;
94
+ top: -1px;
95
+ left: -1px;
96
+ right: -1px;
97
+ bottom: -1px;
98
+ background: linear-gradient(135deg, var(--gradient-start), var(--gradient-end));
99
+ border-radius: 1rem;
100
+ z-index: -1;
101
+ opacity: 0.1;
102
+ transition: opacity 0.3s ease;
103
+ }
104
+
105
+ .card:hover::before {
106
+ opacity: 0.2;
107
+ }
108
+
109
+ .form-group {
110
+ margin-bottom: 1.5rem;
111
+ }
112
+
113
+ .form-group label {
114
+ display: block;
115
+ margin-bottom: 0.5rem;
116
+ color: var(--text-gray);
117
+ }
118
+
119
+ input[type="number"],
120
+ select {
121
+ width: 100%;
122
+ padding: 0.75rem;
123
+ border: 1px solid;
124
+ border-image: linear-gradient(135deg, var(--gradient-start), var(--gradient-end)) 1;
125
+ border-radius: 0.5rem;
126
+ background-color: #374151;
127
+ color: var(--text-light);
128
+ margin-bottom: 0.5rem;
129
+ transition: all 0.3s ease;
130
+ }
131
+
132
+ input[type="number"]:focus,
133
+ select:focus {
134
+ outline: none;
135
+ box-shadow: 0 0 0 2px rgba(99, 102, 241, 0.3);
136
+ }
137
+
138
+ #drawing-canvas {
139
+ background-color: white;
140
+ border-radius: 1rem;
141
+ margin: 2rem auto;
142
+ display: block;
143
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
144
+ }
145
+
146
+ .training-form {
147
+ background-color: var(--card-bg);
148
+ border-radius: 1rem;
149
+ padding: 2rem;
150
+ margin-top: 2rem;
151
+ }
152
+
153
+ .results {
154
+ background-color: var(--card-bg);
155
+ border-radius: 1rem;
156
+ padding: 2rem;
157
+ margin-top: 2rem;
158
+ }
159
+
160
+ #training-logs {
161
+ font-family: 'Roboto Mono', monospace;
162
+ color: var(--text-gray);
163
+ padding: 1rem;
164
+ border-radius: 0.5rem;
165
+ background-color: #374151;
166
+ margin-top: 1rem;
167
+ border: 1px solid;
168
+ border-image: linear-gradient(135deg, var(--gradient-start), var(--gradient-end)) 1;
169
+ }
170
+
171
+ .plot-container {
172
+ background-color: var(--card-bg);
173
+ border-radius: 1rem;
174
+ padding: 1rem;
175
+ margin: 1rem 0;
176
+ border: 1px solid;
177
+ border-image: linear-gradient(135deg, var(--gradient-start), var(--gradient-end)) 1;
178
+ }
179
+
180
+ h2, h3 {
181
+ background: linear-gradient(to right, var(--gradient-start), var(--gradient-end));
182
+ -webkit-background-clip: text;
183
+ -webkit-text-fill-color: transparent;
184
+ margin-bottom: 1.5rem;
185
+ }
186
+
187
+ .features-grid {
188
+ display: grid;
189
+ grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
190
+ gap: 1.5rem;
191
+ margin-top: 2rem;
192
+ }
193
+
194
+ .hero-text {
195
+ font-size: 1.25rem;
196
+ text-align: center;
197
+ color: var(--text-light);
198
+ line-height: 1.6;
199
+ }
200
+
201
+ @media (max-width: 768px) {
202
+ .container {
203
+ padding: 1rem;
204
+ }
205
+
206
+ h1 {
207
+ font-size: 2.5rem;
208
+ }
209
+
210
+ .features-grid {
211
+ grid-template-columns: 1fr;
212
+ }
213
+ }
214
+
215
+ .models-grid {
216
+ display: grid;
217
+ grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
218
+ gap: 2rem;
219
+ margin-bottom: 2rem;
220
+ }
221
+
222
+ .model-config {
223
+ background: var(--card-bg);
224
+ padding: 1.5rem;
225
+ border-radius: 1rem;
226
+ position: relative;
227
+ }
228
+
229
+ .model-config::before {
230
+ content: '';
231
+ position: absolute;
232
+ top: -1px;
233
+ left: -1px;
234
+ right: -1px;
235
+ bottom: -1px;
236
+ background: linear-gradient(135deg, var(--gradient-start), var(--gradient-end));
237
+ border-radius: 1rem;
238
+ z-index: -1;
239
+ opacity: 0.1;
240
+ transition: opacity 0.3s ease;
241
+ }
242
+
243
+ .model-config:hover::before {
244
+ opacity: 0.2;
245
+ }
246
+
247
+ .kernel-inputs {
248
+ display: grid;
249
+ grid-template-columns: repeat(2, 1fr);
250
+ gap: 0.5rem;
251
+ }
252
+
253
+ #comparison-logs {
254
+ display: grid;
255
+ grid-template-columns: repeat(2, 1fr);
256
+ gap: 1rem;
257
+ }
258
+
259
+ .comparison-model {
260
+ padding: 1rem;
261
+ background: #374151;
262
+ border-radius: 0.5rem;
263
+ }
264
+
265
+ .comparison-model h4 {
266
+ margin-top: 0;
267
+ color: var(--text-light);
268
+ }
269
+
270
+ @media (max-width: 768px) {
271
+ .models-grid {
272
+ grid-template-columns: 1fr;
273
+ }
274
+
275
+ #comparison-logs {
276
+ grid-template-columns: 1fr;
277
+ }
278
+ }
static/js/inference.js ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ let canvas, ctx;
2
+
3
+ window.onload = function() {
4
+ canvas = document.getElementById('drawing-canvas');
5
+ ctx = canvas.getContext('2d');
6
+
7
+ setupCanvas();
8
+ };
9
+
10
+ function setupCanvas() {
11
+ ctx.fillStyle = "white";
12
+ ctx.fillRect(0, 0, canvas.width, canvas.height);
13
+
14
+ let drawing = false;
15
+
16
+ canvas.addEventListener('mousedown', startDrawing);
17
+ canvas.addEventListener('mousemove', draw);
18
+ canvas.addEventListener('mouseup', stopDrawing);
19
+ canvas.addEventListener('mouseout', stopDrawing);
20
+
21
+ function startDrawing(e) {
22
+ drawing = true;
23
+ draw(e);
24
+ }
25
+
26
+ function draw(e) {
27
+ if (!drawing) return;
28
+
29
+ const rect = canvas.getBoundingClientRect();
30
+ const x = e.clientX - rect.left;
31
+ const y = e.clientY - rect.top;
32
+
33
+ ctx.lineWidth = 15;
34
+ ctx.lineCap = 'round';
35
+ ctx.strokeStyle = 'black';
36
+ ctx.lineTo(x, y);
37
+ ctx.stroke();
38
+ ctx.beginPath();
39
+ ctx.moveTo(x, y);
40
+ }
41
+
42
+ function stopDrawing() {
43
+ drawing = false;
44
+ ctx.beginPath();
45
+ }
46
+ }
47
+
48
+ function clearCanvas() {
49
+ ctx.fillStyle = "white";
50
+ ctx.fillRect(0, 0, canvas.width, canvas.height);
51
+ }
52
+
53
+ async function predict() {
54
+ const modelSelect = document.getElementById('model-select');
55
+ const selectedModel = modelSelect.value;
56
+
57
+ if (!selectedModel) {
58
+ alert('Please train a model first');
59
+ return;
60
+ }
61
+
62
+ const imageData = canvas.toDataURL('image/png');
63
+
64
+ try {
65
+ const response = await fetch('/api/inference', {
66
+ method: 'POST',
67
+ headers: {
68
+ 'Content-Type': 'application/json',
69
+ },
70
+ body: JSON.stringify({
71
+ image: imageData,
72
+ model_name: selectedModel
73
+ })
74
+ });
75
+
76
+ if (!response.ok) {
77
+ const error = await response.json();
78
+ throw new Error(error.detail || 'Prediction failed');
79
+ }
80
+
81
+ const data = await response.json();
82
+ displayPrediction(data.prediction);
83
+ } catch (error) {
84
+ console.error('Error:', error);
85
+ alert(error.message || 'Error during prediction');
86
+ }
87
+ }
88
+
89
+ function displayPrediction(prediction) {
90
+ const resultDiv = document.getElementById('prediction-result');
91
+ resultDiv.classList.remove('hidden');
92
+ resultDiv.innerHTML = `
93
+ <h2>Prediction Result</h2>
94
+ <p class="prediction-text">Predicted Digit: ${prediction}</p>
95
+ <div class="confidence-bar">
96
+ <div class="confidence-level" style="width: 100%"></div>
97
+ </div>
98
+ `;
99
+ }
static/js/train.js ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ let ws;
2
+ let lossChart;
3
+ let accuracyChart;
4
+
5
+ function showTrainingForm(type) {
6
+ const singleForm = document.getElementById('single-model-form');
7
+ const compareForm = document.getElementById('compare-models-form');
8
+
9
+ if (type === 'single') {
10
+ singleForm.classList.remove('hidden');
11
+ compareForm.classList.add('hidden');
12
+ } else {
13
+ singleForm.classList.add('hidden');
14
+ compareForm.classList.remove('hidden');
15
+ }
16
+ }
17
+
18
+ function initializeCharts() {
19
+ const lossData = [{
20
+ name: 'Training Loss',
21
+ x: [],
22
+ y: [],
23
+ type: 'scatter'
24
+ }, {
25
+ name: 'Validation Loss',
26
+ x: [],
27
+ y: [],
28
+ type: 'scatter'
29
+ }];
30
+
31
+ const accuracyData = [{
32
+ name: 'Training Accuracy',
33
+ x: [],
34
+ y: [],
35
+ type: 'scatter'
36
+ }, {
37
+ name: 'Validation Accuracy',
38
+ x: [],
39
+ y: [],
40
+ type: 'scatter'
41
+ }];
42
+
43
+ Plotly.newPlot('loss-plot', lossData, {
44
+ title: 'Training and Validation Loss',
45
+ xaxis: { title: 'Iterations' },
46
+ yaxis: { title: 'Loss' }
47
+ });
48
+
49
+ Plotly.newPlot('accuracy-plot', accuracyData, {
50
+ title: 'Training and Validation Accuracy',
51
+ xaxis: { title: 'Iterations' },
52
+ yaxis: { title: 'Accuracy (%)' }
53
+ });
54
+ }
55
+
56
+ function updateCharts(data) {
57
+ const iteration = data.epoch * data.batch;
58
+
59
+ Plotly.extendTraces('loss-plot', {
60
+ x: [[iteration], [iteration]],
61
+ y: [[data.train_loss], [data.val_loss]]
62
+ }, [0, 1]);
63
+
64
+ Plotly.extendTraces('accuracy-plot', {
65
+ x: [[iteration], [iteration]],
66
+ y: [[data.train_acc], [data.val_acc]]
67
+ }, [0, 1]);
68
+
69
+ // Update training logs
70
+ const logsDiv = document.getElementById('training-logs');
71
+ logsDiv.innerHTML = `
72
+ <p>Epoch: ${data.epoch + 1}</p>
73
+ <p>Training Loss: ${data.train_loss.toFixed(4)}</p>
74
+ <p>Training Accuracy: ${data.train_acc.toFixed(2)}%</p>
75
+ <p>Validation Loss: ${data.val_loss.toFixed(4)}</p>
76
+ <p>Validation Accuracy: ${data.val_acc.toFixed(2)}%</p>
77
+ `;
78
+ }
79
+
80
+ async function trainSingleModel() {
81
+ const config = {
82
+ kernels: [
83
+ parseInt(document.getElementById('kernel1').value),
84
+ parseInt(document.getElementById('kernel2').value),
85
+ parseInt(document.getElementById('kernel3').value)
86
+ ],
87
+ optimizer: document.getElementById('optimizer').value,
88
+ batch_size: parseInt(document.getElementById('batch_size').value),
89
+ epochs: parseInt(document.getElementById('epochs').value)
90
+ };
91
+
92
+ // Show progress section and initialize charts
93
+ document.getElementById('training-progress').classList.remove('hidden');
94
+ initializeCharts();
95
+
96
+ // Connect to WebSocket
97
+ ws = new WebSocket(`ws://${window.location.host}/ws/train`);
98
+ ws.onmessage = function(event) {
99
+ const data = JSON.parse(event.data);
100
+ updateCharts(data);
101
+ };
102
+
103
+ try {
104
+ const response = await fetch('/api/train_single', {
105
+ method: 'POST',
106
+ headers: {
107
+ 'Content-Type': 'application/json',
108
+ },
109
+ body: JSON.stringify(config)
110
+ });
111
+ const data = await response.json();
112
+
113
+ if (data.status === 'success') {
114
+ alert('Training completed successfully!');
115
+ }
116
+ } catch (error) {
117
+ console.error('Error:', error);
118
+ alert('Error during training. Please check console for details.');
119
+ }
120
+ }
121
+
122
+ async function compareModels() {
123
+ const config = {
124
+ model1: {
125
+ kernels: [
126
+ parseInt(document.getElementById('model1_kernel1').value),
127
+ parseInt(document.getElementById('model1_kernel2').value),
128
+ parseInt(document.getElementById('model1_kernel3').value)
129
+ ],
130
+ optimizer: document.getElementById('model1_optimizer').value,
131
+ batch_size: parseInt(document.getElementById('model1_batch_size').value),
132
+ epochs: parseInt(document.getElementById('model1_epochs').value)
133
+ },
134
+ model2: {
135
+ kernels: [
136
+ parseInt(document.getElementById('model2_kernel1').value),
137
+ parseInt(document.getElementById('model2_kernel2').value),
138
+ parseInt(document.getElementById('model2_kernel3').value)
139
+ ],
140
+ optimizer: document.getElementById('model2_optimizer').value,
141
+ batch_size: parseInt(document.getElementById('model2_batch_size').value),
142
+ epochs: parseInt(document.getElementById('model2_epochs').value)
143
+ }
144
+ };
145
+
146
+ // Show comparison progress section
147
+ document.getElementById('comparison-progress').classList.remove('hidden');
148
+ initializeComparisonCharts();
149
+
150
+ try {
151
+ const response = await fetch('/api/train_compare', {
152
+ method: 'POST',
153
+ headers: {
154
+ 'Content-Type': 'application/json',
155
+ },
156
+ body: JSON.stringify(config)
157
+ });
158
+ const data = await response.json();
159
+
160
+ if (data.status === 'success') {
161
+ displayComparisonResults(data);
162
+ alert('Model comparison completed successfully!');
163
+ }
164
+ } catch (error) {
165
+ console.error('Error:', error);
166
+ alert('Error during model comparison. Please check console for details.');
167
+ }
168
+ }
169
+
170
+ function initializeComparisonCharts() {
171
+ const lossData = [{
172
+ name: 'Model 1 Loss',
173
+ x: [],
174
+ y: [],
175
+ type: 'scatter'
176
+ }, {
177
+ name: 'Model 2 Loss',
178
+ x: [],
179
+ y: [],
180
+ type: 'scatter'
181
+ }];
182
+
183
+ const accuracyData = [{
184
+ name: 'Model 1 Accuracy',
185
+ x: [],
186
+ y: [],
187
+ type: 'scatter'
188
+ }, {
189
+ name: 'Model 2 Accuracy',
190
+ x: [],
191
+ y: [],
192
+ type: 'scatter'
193
+ }];
194
+
195
+ Plotly.newPlot('comparison-loss-plot', lossData, {
196
+ title: 'Loss Comparison',
197
+ xaxis: { title: 'Iterations' },
198
+ yaxis: { title: 'Loss' }
199
+ });
200
+
201
+ Plotly.newPlot('comparison-accuracy-plot', accuracyData, {
202
+ title: 'Accuracy Comparison',
203
+ xaxis: { title: 'Iterations' },
204
+ yaxis: { title: 'Accuracy (%)' }
205
+ });
206
+ }
207
+
208
+ function displayComparisonResults(data) {
209
+ const logsDiv = document.getElementById('comparison-logs');
210
+ logsDiv.innerHTML = `
211
+ <div class="comparison-model">
212
+ <h4>Model 1</h4>
213
+ <p>Final Loss: ${data.model1_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
214
+ <p>Final Accuracy: ${data.model1_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
215
+ <p>Model Name: ${data.model1_results.model_name}</p>
216
+ </div>
217
+ <div class="comparison-model">
218
+ <h4>Model 2</h4>
219
+ <p>Final Loss: ${data.model2_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
220
+ <p>Final Accuracy: ${data.model2_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
221
+ <p>Model Name: ${data.model2_results.model_name}</p>
222
+ </div>
223
+ `;
224
+ }
225
+
226
+ function displayResults(data) {
227
+ const resultsDiv = document.getElementById('training-results');
228
+ // Display training results
229
+ }
static/js/train_compare.js ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ let ws;
2
+
3
+ function initializeComparisonCharts() {
4
+ const lossData = [{
5
+ name: 'Model 1 Loss',
6
+ x: [],
7
+ y: [],
8
+ type: 'scatter'
9
+ }, {
10
+ name: 'Model 2 Loss',
11
+ x: [],
12
+ y: [],
13
+ type: 'scatter'
14
+ }];
15
+
16
+ const accuracyData = [{
17
+ name: 'Model 1 Accuracy',
18
+ x: [],
19
+ y: [],
20
+ type: 'scatter'
21
+ }, {
22
+ name: 'Model 2 Accuracy',
23
+ x: [],
24
+ y: [],
25
+ type: 'scatter'
26
+ }];
27
+
28
+ Plotly.newPlot('comparison-loss-plot', lossData, {
29
+ title: 'Loss Comparison',
30
+ xaxis: { title: 'Iterations' },
31
+ yaxis: { title: 'Loss' }
32
+ });
33
+
34
+ Plotly.newPlot('comparison-accuracy-plot', accuracyData, {
35
+ title: 'Accuracy Comparison',
36
+ xaxis: { title: 'Iterations' },
37
+ yaxis: { title: 'Accuracy (%)' }
38
+ });
39
+ }
40
+
41
+ async function compareModels() {
42
+ const config = {
43
+ model1: {
44
+ kernels: [
45
+ parseInt(document.getElementById('model1_kernel1').value),
46
+ parseInt(document.getElementById('model1_kernel2').value),
47
+ parseInt(document.getElementById('model1_kernel3').value)
48
+ ],
49
+ optimizer: document.getElementById('model1_optimizer').value,
50
+ batch_size: parseInt(document.getElementById('model1_batch_size').value),
51
+ epochs: parseInt(document.getElementById('model1_epochs').value)
52
+ },
53
+ model2: {
54
+ kernels: [
55
+ parseInt(document.getElementById('model2_kernel1').value),
56
+ parseInt(document.getElementById('model2_kernel2').value),
57
+ parseInt(document.getElementById('model2_kernel3').value)
58
+ ],
59
+ optimizer: document.getElementById('model2_optimizer').value,
60
+ batch_size: parseInt(document.getElementById('model2_batch_size').value),
61
+ epochs: parseInt(document.getElementById('model2_epochs').value)
62
+ }
63
+ };
64
+
65
+ // Show comparison progress section
66
+ document.getElementById('comparison-progress').classList.remove('hidden');
67
+ initializeComparisonCharts();
68
+
69
+ try {
70
+ const response = await fetch('/api/train_compare', {
71
+ method: 'POST',
72
+ headers: {
73
+ 'Content-Type': 'application/json',
74
+ },
75
+ body: JSON.stringify(config)
76
+ });
77
+ const data = await response.json();
78
+
79
+ if (data.status === 'success') {
80
+ displayComparisonResults(data);
81
+ alert('Model comparison completed successfully!');
82
+ }
83
+ } catch (error) {
84
+ console.error('Error:', error);
85
+ alert('Error during model comparison. Please check console for details.');
86
+ }
87
+ }
88
+
89
+ function displayComparisonResults(data) {
90
+ const logsDiv = document.getElementById('comparison-logs');
91
+ logsDiv.innerHTML = `
92
+ <div class="comparison-model">
93
+ <h4>Model 1</h4>
94
+ <p>Final Loss: ${data.model1_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
95
+ <p>Final Accuracy: ${data.model1_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
96
+ <p>Model Name: ${data.model1_results.model_name}</p>
97
+ </div>
98
+ <div class="comparison-model">
99
+ <h4>Model 2</h4>
100
+ <p>Final Loss: ${data.model2_results.history.train_loss.slice(-1)[0].toFixed(4)}</p>
101
+ <p>Final Accuracy: ${data.model2_results.history.train_acc.slice(-1)[0].toFixed(2)}%</p>
102
+ <p>Model Name: ${data.model2_results.model_name}</p>
103
+ </div>
104
+ `;
105
+ }
static/js/train_single.js ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ let ws;
2
+
3
+ function initializeCharts() {
4
+ const lossData = [{
5
+ name: 'Training Loss',
6
+ x: [],
7
+ y: [],
8
+ type: 'scatter'
9
+ }, {
10
+ name: 'Validation Loss',
11
+ x: [],
12
+ y: [],
13
+ type: 'scatter'
14
+ }];
15
+
16
+ const accuracyData = [{
17
+ name: 'Training Accuracy',
18
+ x: [],
19
+ y: [],
20
+ type: 'scatter'
21
+ }, {
22
+ name: 'Validation Accuracy',
23
+ x: [],
24
+ y: [],
25
+ type: 'scatter'
26
+ }];
27
+
28
+ Plotly.newPlot('loss-plot', lossData, {
29
+ title: 'Training and Validation Loss',
30
+ xaxis: { title: 'Iterations' },
31
+ yaxis: { title: 'Loss' }
32
+ });
33
+
34
+ Plotly.newPlot('accuracy-plot', accuracyData, {
35
+ title: 'Training and Validation Accuracy',
36
+ xaxis: { title: 'Iterations' },
37
+ yaxis: { title: 'Accuracy (%)' }
38
+ });
39
+ }
40
+
41
+ function updateCharts(data) {
42
+ const iteration = data.epoch * data.batch;
43
+
44
+ Plotly.extendTraces('loss-plot', {
45
+ x: [[iteration], [iteration]],
46
+ y: [[data.train_loss], [data.val_loss]]
47
+ }, [0, 1]);
48
+
49
+ Plotly.extendTraces('accuracy-plot', {
50
+ x: [[iteration], [iteration]],
51
+ y: [[data.train_acc], [data.val_acc]]
52
+ }, [0, 1]);
53
+
54
+ // Update training logs
55
+ const logsDiv = document.getElementById('training-logs');
56
+ logsDiv.innerHTML = `
57
+ <p>Epoch: ${data.epoch + 1}</p>
58
+ <p>Training Loss: ${data.train_loss.toFixed(4)}</p>
59
+ <p>Training Accuracy: ${data.train_acc.toFixed(2)}%</p>
60
+ <p>Validation Loss: ${data.val_loss.toFixed(4)}</p>
61
+ <p>Validation Accuracy: ${data.val_acc.toFixed(2)}%</p>
62
+ `;
63
+ }
64
+
65
+ async function trainModel() {
66
+ console.log("Training started..."); // Debug log
67
+ const config = {
68
+ kernels: [
69
+ parseInt(document.getElementById('kernel1').value),
70
+ parseInt(document.getElementById('kernel2').value),
71
+ parseInt(document.getElementById('kernel3').value)
72
+ ],
73
+ optimizer: document.getElementById('optimizer').value,
74
+ batch_size: parseInt(document.getElementById('batch_size').value),
75
+ epochs: parseInt(document.getElementById('epochs').value)
76
+ };
77
+
78
+ console.log("Config:", config); // Debug log
79
+
80
+ // Show progress section and initialize charts
81
+ document.getElementById('training-progress').classList.remove('hidden');
82
+ initializeCharts();
83
+
84
+ try {
85
+ // Connect to WebSocket
86
+ console.log("Connecting to WebSocket..."); // Debug log
87
+ ws = new WebSocket(`ws://${window.location.host}/ws/train`);
88
+
89
+ ws.onopen = function() {
90
+ console.log("WebSocket connection established");
91
+ // Send configuration once connected
92
+ ws.send(JSON.stringify(config));
93
+ console.log("Config sent to server"); // Debug log
94
+ };
95
+
96
+ ws.onmessage = function(event) {
97
+ console.log("Received message:", event.data); // Debug log
98
+ const data = JSON.parse(event.data);
99
+ if (data.status === "completed") {
100
+ alert('Training completed successfully!');
101
+ } else if (data.status === "error") {
102
+ alert('Error during training: ' + data.message);
103
+ } else {
104
+ updateCharts(data);
105
+ }
106
+ };
107
+
108
+ ws.onerror = function(error) {
109
+ console.error('WebSocket error:', error);
110
+ alert('Error connecting to training server');
111
+ };
112
+
113
+ ws.onclose = function() {
114
+ console.log('WebSocket connection closed');
115
+ };
116
+
117
+ } catch (error) {
118
+ console.error('Error:', error);
119
+ alert('Error during training: ' + error.message);
120
+ }
121
+ }
templates/index.html ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>MNIST Digit Classification</title>
7
+ <link rel="stylesheet" href="{{ url_for('static', path='/css/style.css') }}">
8
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
9
+ </head>
10
+ <body>
11
+ <div class="container">
12
+ <h1>MNIST Digit Classification</h1>
13
+ <div class="card">
14
+ <p class="hero-text">Train and test neural networks for handwritten digit recognition using the MNIST dataset.</p>
15
+ </div>
16
+ <div class="button-container">
17
+ <a href="/train" class="btn">Train Model</a>
18
+ <a href="/inference" class="btn">Test Model</a>
19
+ </div>
20
+
21
+ <div class="features-grid">
22
+ <div class="card">
23
+ <h3>Train Models</h3>
24
+ <p>Configure and train custom neural networks with different architectures.</p>
25
+ </div>
26
+ <div class="card">
27
+ <h3>Compare Performance</h3>
28
+ <p>Train multiple models simultaneously and compare their performance.</p>
29
+ </div>
30
+ <div class="card">
31
+ <h3>Real-time Visualization</h3>
32
+ <p>Monitor training progress with live loss and accuracy curves.</p>
33
+ </div>
34
+ <div class="card">
35
+ <h3>Interactive Testing</h3>
36
+ <p>Draw digits and test the model's prediction capabilities in real-time.</p>
37
+ </div>
38
+ </div>
39
+ </div>
40
+ </body>
41
+ </html>
templates/inference.html ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Test Model - MNIST</title>
7
+ <link rel="stylesheet" href="{{ url_for('static', path='/css/style.css') }}">
8
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
9
+ </head>
10
+ <body>
11
+ <div class="container">
12
+ <h1>Test Your Model</h1>
13
+ <div class="card">
14
+ <p>Draw a digit (0-9) in the canvas below and click "Predict" to test the model.</p>
15
+
16
+ <div class="form-group">
17
+ <label for="model-select">Select Model</label>
18
+ <select id="model-select" required>
19
+ {% if available_models %}
20
+ {% for model in available_models %}
21
+ <option value="{{ model }}">{{ model }}</option>
22
+ {% endfor %}
23
+ {% else %}
24
+ <option value="">No models available - Train a model first</option>
25
+ {% endif %}
26
+ </select>
27
+ </div>
28
+ </div>
29
+
30
+ <canvas id="drawing-canvas" width="280" height="280"></canvas>
31
+ <div class="button-container">
32
+ <button onclick="clearCanvas()" class="btn">Clear Canvas</button>
33
+ <button onclick="predict()" class="btn" {% if not available_models %}disabled{% endif %}>Predict</button>
34
+ </div>
35
+ <div id="prediction-result" class="card hidden">
36
+ <!-- Prediction result will be displayed here -->
37
+ </div>
38
+ </div>
39
+ <script src="{{ url_for('static', path='/js/inference.js') }}"></script>
40
+ </body>
41
+ </html>
templates/train.html ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Train Models - MNIST</title>
7
+ <link rel="stylesheet" href="{{ url_for('static', path='/css/style.css') }}">
8
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Roboto+Mono&display=swap" rel="stylesheet">
9
+ </head>
10
+ <body>
11
+ <div class="container">
12
+ <h1>Train Your Model</h1>
13
+ <div class="button-container">
14
+ <a href="/train/single" class="btn">Train Single Model</a>
15
+ <a href="/train/compare" class="btn">Compare Models</a>
16
+ </div>
17
+ </div>
18
+ </body>
19
+ </html>
templates/train_compare.html ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Compare Models - MNIST</title>
7
+ <link rel="stylesheet" href="{{ url_for('static', path='/css/style.css') }}">
8
+ <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
9
+ </head>
10
+ <body>
11
+ <div class="container">
12
+ <h1>Compare Models</h1>
13
+ <div class="models-grid">
14
+ <!-- Model 1 Configuration -->
15
+ <div class="model-config">
16
+ <h3>Model 1</h3>
17
+ <div class="network-config">
18
+ <h4>Network Architecture</h4>
19
+ <div class="block-config">
20
+ <div class="block">
21
+ <label for="model1_block1">Block-1:</label>
22
+ <select id="model1_block1" name="block1" class="form-select">
23
+ <option value="8">8</option>
24
+ <option value="16">16</option>
25
+ <option value="32" selected>32</option>
26
+ <option value="64">64</option>
27
+ <option value="128">128</option>
28
+ </select>
29
+ </div>
30
+
31
+ <div class="block">
32
+ <label for="model1_block2">Block-2:</label>
33
+ <select id="model1_block2" name="block2" class="form-select">
34
+ <option value="8">8</option>
35
+ <option value="16">16</option>
36
+ <option value="32">32</option>
37
+ <option value="64" selected>64</option>
38
+ <option value="128">128</option>
39
+ </select>
40
+ </div>
41
+
42
+ <div class="block">
43
+ <label for="model1_block3">Block-3:</label>
44
+ <select id="model1_block3" name="block3" class="form-select">
45
+ <option value="8">8</option>
46
+ <option value="16">16</option>
47
+ <option value="32">32</option>
48
+ <option value="64">64</option>
49
+ <option value="128" selected>128</option>
50
+ </select>
51
+ </div>
52
+ </div>
53
+ </div>
54
+ <div class="training-config">
55
+ <div class="config-item">
56
+ <label for="model1_optimizer">Optimizer:</label>
57
+ <select id="model1_optimizer" name="optimizer">
58
+ <option value="SGD" selected>SGD</option>
59
+ <option value="Adam">Adam</option>
60
+ </select>
61
+ </div>
62
+ <div class="config-item">
63
+ <label for="model1_batch_size">Batch Size:</label>
64
+ <select id="model1_batch_size" name="batch_size">
65
+ <option value="32">32</option>
66
+ <option value="64" selected>64</option>
67
+ <option value="128">128</option>
68
+ </select>
69
+ </div>
70
+ </div>
71
+ </div>
72
+
73
+ <!-- Model 2 Configuration -->
74
+ <div class="model-config">
75
+ <h3>Model 2</h3>
76
+ <div class="network-config">
77
+ <h4>Network Architecture</h4>
78
+ <div class="block-config">
79
+ <div class="block">
80
+ <label for="model2_block1">Block-1:</label>
81
+ <select id="model2_block1" name="block1" class="form-select">
82
+ <option value="8">8</option>
83
+ <option value="16">16</option>
84
+ <option value="32" selected>32</option>
85
+ <option value="64">64</option>
86
+ <option value="128">128</option>
87
+ </select>
88
+ </div>
89
+
90
+ <div class="block">
91
+ <label for="model2_block2">Block-2:</label>
92
+ <select id="model2_block2" name="block2" class="form-select">
93
+ <option value="8">8</option>
94
+ <option value="16">16</option>
95
+ <option value="32">32</option>
96
+ <option value="64" selected>64</option>
97
+ <option value="128">128</option>
98
+ </select>
99
+ </div>
100
+
101
+ <div class="block">
102
+ <label for="model2_block3">Block-3:</label>
103
+ <select id="model2_block3" name="block3" class="form-select">
104
+ <option value="8">8</option>
105
+ <option value="16">16</option>
106
+ <option value="32">32</option>
107
+ <option value="64">64</option>
108
+ <option value="128" selected>128</option>
109
+ </select>
110
+ </div>
111
+ </div>
112
+ </div>
113
+ <div class="training-config">
114
+ <div class="config-item">
115
+ <label for="model2_optimizer">Optimizer:</label>
116
+ <select id="model2_optimizer" name="optimizer">
117
+ <option value="SGD" selected>SGD</option>
118
+ <option value="Adam">Adam</option>
119
+ </select>
120
+ </div>
121
+ <div class="config-item">
122
+ <label for="model2_batch_size">Batch Size:</label>
123
+ <select id="model2_batch_size" name="batch_size">
124
+ <option value="32">32</option>
125
+ <option value="64" selected>64</option>
126
+ <option value="128">128</option>
127
+ </select>
128
+ </div>
129
+ </div>
130
+ </div>
131
+ </div>
132
+
133
+ <!-- Training Controls -->
134
+ <div class="controls">
135
+ <button id="startComparison" onclick="startComparison()">Start Comparison</button>
136
+ <button id="stopComparison" onclick="stopComparison()" disabled>Stop Comparison</button>
137
+ </div>
138
+
139
+ <!-- Training Progress -->
140
+ <div class="charts-container">
141
+ <div id="lossChart"></div>
142
+ <div id="accuracyChart"></div>
143
+ </div>
144
+ </div>
145
+
146
+ <style>
147
+ .container {
148
+ max-width: 1200px;
149
+ margin: 0 auto;
150
+ padding: 20px;
151
+ }
152
+
153
+ .models-grid {
154
+ display: grid;
155
+ grid-template-columns: 1fr 1fr;
156
+ gap: 20px;
157
+ margin-bottom: 20px;
158
+ }
159
+
160
+ .model-config {
161
+ padding: 20px;
162
+ border: 1px solid #ddd;
163
+ border-radius: 5px;
164
+ margin-bottom: 20px;
165
+ }
166
+
167
+ .network-config {
168
+ margin-bottom: 20px;
169
+ }
170
+
171
+ .network-config h4 {
172
+ margin: 0 0 15px 0;
173
+ font-size: 1.1em;
174
+ }
175
+
176
+ .block-config {
177
+ display: flex;
178
+ justify-content: space-between;
179
+ gap: 20px;
180
+ }
181
+
182
+ .block {
183
+ flex: 1;
184
+ }
185
+
186
+ .block label {
187
+ display: block;
188
+ margin-bottom: 5px;
189
+ font-weight: bold;
190
+ }
191
+
192
+ .training-config {
193
+ display: flex;
194
+ gap: 20px;
195
+ }
196
+
197
+ .config-item {
198
+ flex: 1;
199
+ }
200
+
201
+ .config-item label {
202
+ display: block;
203
+ margin-bottom: 5px;
204
+ font-weight: bold;
205
+ }
206
+
207
+ select {
208
+ width: 100%;
209
+ padding: 8px;
210
+ border: 1px solid #ddd;
211
+ border-radius: 4px;
212
+ }
213
+
214
+ .controls {
215
+ margin: 20px 0;
216
+ text-align: center;
217
+ }
218
+
219
+ button {
220
+ padding: 10px 20px;
221
+ margin-right: 10px;
222
+ border: none;
223
+ border-radius: 4px;
224
+ background-color: #007bff;
225
+ color: white;
226
+ cursor: pointer;
227
+ }
228
+
229
+ button:disabled {
230
+ background-color: #ccc;
231
+ cursor: not-allowed;
232
+ }
233
+
234
+ .charts-container {
235
+ display: flex;
236
+ flex-direction: column;
237
+ gap: 20px;
238
+ margin-top: 20px;
239
+ }
240
+
241
+ #lossChart, #accuracyChart {
242
+ height: 400px;
243
+ width: 100%;
244
+ }
245
+
246
+ h4 {
247
+ margin: 0 0 10px 0;
248
+ }
249
+
250
+ .section-title {
251
+ color: white;
252
+ font-weight: bold;
253
+ margin: 0 0 10px 0;
254
+ font-size: 1.1em;
255
+ text-transform: uppercase;
256
+ }
257
+
258
+ .network-config .section-title {
259
+ margin: 0 0 15px 0;
260
+ }
261
+
262
+ .config-item .section-title {
263
+ margin-bottom: 5px;
264
+ }
265
+ </style>
266
+
267
+ <script>
268
+ let ws;
269
+ let lossChart;
270
+ let accuracyChart;
271
+
272
+ // Initialize charts
273
+ document.addEventListener('DOMContentLoaded', function() {
274
+ // Loss chart configuration
275
+ const lossData = [
276
+ {
277
+ x: [],
278
+ y: [],
279
+ name: 'Model 1 Training Loss',
280
+ type: 'scatter'
281
+ },
282
+ {
283
+ x: [],
284
+ y: [],
285
+ name: 'Model 2 Training Loss',
286
+ type: 'scatter'
287
+ }
288
+ ];
289
+
290
+ const lossLayout = {
291
+ title: 'Training Loss Comparison',
292
+ xaxis: {
293
+ title: 'Iterations',
294
+ rangemode: 'nonnegative'
295
+ },
296
+ yaxis: {
297
+ title: 'Loss',
298
+ rangemode: 'nonnegative'
299
+ }
300
+ };
301
+
302
+ // Accuracy chart configuration
303
+ const accuracyData = [
304
+ {
305
+ x: [],
306
+ y: [],
307
+ name: 'Model 1 Training Accuracy',
308
+ type: 'scatter'
309
+ },
310
+ {
311
+ x: [],
312
+ y: [],
313
+ name: 'Model 2 Training Accuracy',
314
+ type: 'scatter'
315
+ }
316
+ ];
317
+
318
+ const accuracyLayout = {
319
+ title: 'Training Accuracy Comparison',
320
+ xaxis: {
321
+ title: 'Iterations',
322
+ rangemode: 'nonnegative'
323
+ },
324
+ yaxis: {
325
+ title: 'Accuracy (%)',
326
+ range: [0, 100]
327
+ }
328
+ };
329
+
330
+ // Create charts
331
+ Plotly.newPlot('lossChart', lossData, lossLayout);
332
+ Plotly.newPlot('accuracyChart', accuracyData, accuracyLayout);
333
+ });
334
+
335
+ function startComparison() {
336
+ // Disable start button and enable stop button
337
+ document.getElementById('startComparison').disabled = true;
338
+ document.getElementById('stopComparison').disabled = false;
339
+
340
+ // Get configuration for both models
341
+ const model1Config = {
342
+ block1: parseInt(document.getElementById('model1_block1').value),
343
+ block2: parseInt(document.getElementById('model1_block2').value),
344
+ block3: parseInt(document.getElementById('model1_block3').value),
345
+ optimizer: document.getElementById('model1_optimizer').value,
346
+ batch_size: parseInt(document.getElementById('model1_batch_size').value)
347
+ };
348
+
349
+ const model2Config = {
350
+ block1: parseInt(document.getElementById('model2_block1').value),
351
+ block2: parseInt(document.getElementById('model2_block2').value),
352
+ block3: parseInt(document.getElementById('model2_block3').value),
353
+ optimizer: document.getElementById('model2_optimizer').value,
354
+ batch_size: parseInt(document.getElementById('model2_batch_size').value)
355
+ };
356
+
357
+ // Setup WebSocket connection
358
+ ws = new WebSocket(`ws://${window.location.host}/ws/compare`);
359
+
360
+ ws.onmessage = function(event) {
361
+ const data = JSON.parse(event.data);
362
+
363
+ if (data.type === 'training_update') {
364
+ const modelIndex = data.data.model_id - 1; // 0 for model1, 1 for model2
365
+
366
+ // Update training metrics
367
+ Plotly.extendTraces('lossChart', {
368
+ x: [[data.data.step]],
369
+ y: [[data.data.train_loss]]
370
+ }, [modelIndex]);
371
+
372
+ Plotly.extendTraces('accuracyChart', {
373
+ x: [[data.data.step]],
374
+ y: [[data.data.train_acc]]
375
+ }, [modelIndex]);
376
+ }
377
+ else if (data.type === 'validation_update') {
378
+ const modelIndex = data.data.model_id - 1;
379
+
380
+ // Add validation points
381
+ Plotly.addTraces('lossChart', {
382
+ x: [data.data.step],
383
+ y: [data.data.val_loss],
384
+ name: `Model ${data.data.model_id} Validation Loss`,
385
+ mode: 'markers',
386
+ marker: { size: 8 }
387
+ });
388
+
389
+ Plotly.addTraces('accuracyChart', {
390
+ x: [data.data.step],
391
+ y: [data.data.val_acc],
392
+ name: `Model ${data.data.model_id} Validation Accuracy`,
393
+ mode: 'markers',
394
+ marker: { size: 8 }
395
+ });
396
+ }
397
+ else if (data.type === 'comparison_complete') {
398
+ document.getElementById('startComparison').disabled = false;
399
+ document.getElementById('stopComparison').disabled = true;
400
+ }
401
+ };
402
+
403
+ // Start comparison
404
+ ws.send(JSON.stringify({
405
+ type: 'start_comparison',
406
+ model1: model1Config,
407
+ model2: model2Config
408
+ }));
409
+ }
410
+
411
+ function stopComparison() {
412
+ if (ws) {
413
+ ws.close();
414
+ }
415
+ document.getElementById('startComparison').disabled = false;
416
+ document.getElementById('stopComparison').disabled = true;
417
+ }
418
+ </script>
419
+ </body>
420
+ </html>
templates/train_single.html ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Train Single Model - MNIST</title>
7
+ <link rel="stylesheet" href="{{ url_for('static', path='/css/style.css') }}">
8
+ <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
9
+ </head>
10
+ <body>
11
+ <div class="container">
12
+ <h1>Train Single Model</h1>
13
+
14
+ <!-- Network Architecture Section -->
15
+ <div class="model-config">
16
+ <h3>Model Configuration</h3>
17
+ <div class="network-config">
18
+ <h4>Network Architecture</h4>
19
+ <div class="block-config">
20
+ <div class="block">
21
+ <label for="block1">Block-1:</label>
22
+ <select id="block1" name="block1" class="form-select">
23
+ <option value="8">8</option>
24
+ <option value="16">16</option>
25
+ <option value="32" selected>32</option>
26
+ <option value="64">64</option>
27
+ <option value="128">128</option>
28
+ </select>
29
+ </div>
30
+
31
+ <div class="block">
32
+ <label for="block2">Block-2:</label>
33
+ <select id="block2" name="block2" class="form-select">
34
+ <option value="8">8</option>
35
+ <option value="16">16</option>
36
+ <option value="32">32</option>
37
+ <option value="64" selected>64</option>
38
+ <option value="128">128</option>
39
+ </select>
40
+ </div>
41
+
42
+ <div class="block">
43
+ <label for="block3">Block-3:</label>
44
+ <select id="block3" name="block3" class="form-select">
45
+ <option value="8">8</option>
46
+ <option value="16">16</option>
47
+ <option value="32">32</option>
48
+ <option value="64">64</option>
49
+ <option value="128" selected>128</option>
50
+ </select>
51
+ </div>
52
+ </div>
53
+ </div>
54
+ <div class="training-config">
55
+ <div class="config-item">
56
+ <label for="optimizer">Optimizer:</label>
57
+ <select id="optimizer" name="optimizer">
58
+ <option value="SGD" selected>SGD</option>
59
+ <option value="Adam">Adam</option>
60
+ </select>
61
+ </div>
62
+ <div class="config-item">
63
+ <label for="batch_size">Batch Size:</label>
64
+ <select id="batch_size" name="batch_size">
65
+ <option value="32">32</option>
66
+ <option value="64" selected>64</option>
67
+ <option value="128">128</option>
68
+ </select>
69
+ </div>
70
+ </div>
71
+ </div>
72
+
73
+ <!-- Training Controls -->
74
+ <div class="controls">
75
+ <button id="startTraining" onclick="startTraining()">Start Training</button>
76
+ <button id="stopTraining" onclick="stopTraining()" disabled>Stop Training</button>
77
+ </div>
78
+
79
+ <!-- Training Progress -->
80
+ <div class="charts-container">
81
+ <div id="lossChart"></div>
82
+ <div id="accuracyChart"></div>
83
+ </div>
84
+ </div>
85
+
86
+ <script>
87
+ let ws;
88
+ let lossChart;
89
+ let accuracyChart;
90
+
91
+ // Initialize charts
92
+ document.addEventListener('DOMContentLoaded', function() {
93
+ // Loss chart configuration
94
+ const lossData = [
95
+ {
96
+ x: [],
97
+ y: [],
98
+ name: 'Training Loss',
99
+ type: 'scatter'
100
+ },
101
+ {
102
+ x: [],
103
+ y: [],
104
+ name: 'Validation Loss',
105
+ type: 'scatter'
106
+ }
107
+ ];
108
+
109
+ const lossLayout = {
110
+ title: 'Loss',
111
+ xaxis: {
112
+ title: 'Iterations',
113
+ rangemode: 'nonnegative'
114
+ },
115
+ yaxis: {
116
+ title: 'Loss',
117
+ rangemode: 'nonnegative'
118
+ }
119
+ };
120
+
121
+ // Accuracy chart configuration
122
+ const accuracyData = [
123
+ {
124
+ x: [],
125
+ y: [],
126
+ name: 'Training Accuracy',
127
+ type: 'scatter'
128
+ },
129
+ {
130
+ x: [],
131
+ y: [],
132
+ name: 'Validation Accuracy',
133
+ type: 'scatter'
134
+ }
135
+ ];
136
+
137
+ const accuracyLayout = {
138
+ title: 'Accuracy',
139
+ xaxis: {
140
+ title: 'Iterations',
141
+ rangemode: 'nonnegative'
142
+ },
143
+ yaxis: {
144
+ title: 'Accuracy (%)',
145
+ range: [0, 100]
146
+ }
147
+ };
148
+
149
+ // Create charts
150
+ Plotly.newPlot('lossChart', lossData, lossLayout);
151
+ Plotly.newPlot('accuracyChart', accuracyData, accuracyLayout);
152
+ });
153
+
154
+ function startTraining() {
155
+ // Disable start button and enable stop button
156
+ document.getElementById('startTraining').disabled = true;
157
+ document.getElementById('stopTraining').disabled = false;
158
+
159
+ // Clear previous charts
160
+ Plotly.purge('lossChart');
161
+ Plotly.purge('accuracyChart');
162
+
163
+ // Initialize new charts
164
+ const lossData = [
165
+ {
166
+ x: [],
167
+ y: [],
168
+ name: 'Training Loss',
169
+ type: 'scatter'
170
+ },
171
+ {
172
+ x: [],
173
+ y: [],
174
+ name: 'Validation Loss',
175
+ type: 'scatter'
176
+ }
177
+ ];
178
+
179
+ const accuracyData = [
180
+ {
181
+ x: [],
182
+ y: [],
183
+ name: 'Training Accuracy',
184
+ type: 'scatter'
185
+ },
186
+ {
187
+ x: [],
188
+ y: [],
189
+ name: 'Validation Accuracy',
190
+ type: 'scatter'
191
+ }
192
+ ];
193
+
194
+ Plotly.newPlot('lossChart', lossData, {
195
+ title: 'Loss',
196
+ xaxis: { title: 'Iterations', rangemode: 'nonnegative' },
197
+ yaxis: { title: 'Loss', rangemode: 'nonnegative' }
198
+ });
199
+
200
+ Plotly.newPlot('accuracyChart', accuracyData, {
201
+ title: 'Accuracy',
202
+ xaxis: { title: 'Iterations', rangemode: 'nonnegative' },
203
+ yaxis: { title: 'Accuracy (%)', range: [0, 100] }
204
+ });
205
+
206
+ // Setup WebSocket connection
207
+ ws = new WebSocket(`ws://${window.location.host}/ws/train`);
208
+
209
+ ws.onopen = function() {
210
+ console.log("WebSocket connection established");
211
+ // Send configuration through WebSocket
212
+ const config = {
213
+ block1: parseInt(document.getElementById('block1').value),
214
+ block2: parseInt(document.getElementById('block2').value),
215
+ block3: parseInt(document.getElementById('block3').value),
216
+ optimizer: document.getElementById('optimizer').value,
217
+ batch_size: parseInt(document.getElementById('batch_size').value),
218
+ epochs: 1 // Add default epochs value
219
+ };
220
+ ws.send(JSON.stringify(config));
221
+ };
222
+
223
+ ws.onerror = function(error) {
224
+ console.error("WebSocket error:", error);
225
+ stopTraining();
226
+ alert("Error connecting to training server");
227
+ };
228
+
229
+ ws.onclose = function() {
230
+ console.log("WebSocket connection closed");
231
+ stopTraining();
232
+ };
233
+
234
+ ws.onmessage = function(event) {
235
+ const data = JSON.parse(event.data);
236
+
237
+ if (data.type === 'training_update') {
238
+ // Update training metrics (trace index 0)
239
+ Plotly.extendTraces('lossChart', {
240
+ x: [[data.data.step]],
241
+ y: [[data.data.train_loss]]
242
+ }, [0]);
243
+
244
+ Plotly.extendTraces('accuracyChart', {
245
+ x: [[data.data.step]],
246
+ y: [[data.data.train_acc]]
247
+ }, [0]);
248
+ }
249
+ else if (data.type === 'validation_update') {
250
+ // Update validation metrics (trace index 1)
251
+ Plotly.extendTraces('lossChart', {
252
+ x: [[data.data.step]],
253
+ y: [[data.data.val_loss]]
254
+ }, [1]);
255
+
256
+ Plotly.extendTraces('accuracyChart', {
257
+ x: [[data.data.step]],
258
+ y: [[data.data.val_acc]]
259
+ }, [1]);
260
+ }
261
+ else if (data.type === 'training_complete') {
262
+ alert(data.data.message);
263
+ stopTraining();
264
+ }
265
+ else if (data.type === 'training_error') {
266
+ alert(data.data.message);
267
+ stopTraining();
268
+ }
269
+ };
270
+ }
271
+
272
+ function stopTraining() {
273
+ if (ws) {
274
+ ws.close();
275
+ }
276
+ document.getElementById('startTraining').disabled = false;
277
+ document.getElementById('stopTraining').disabled = true;
278
+ }
279
+ </script>
280
+
281
+ <style>
282
+ .container {
283
+ max-width: 1200px;
284
+ margin: 0 auto;
285
+ padding: 20px;
286
+ }
287
+
288
+ .model-config {
289
+ padding: 20px;
290
+ border: 1px solid #ddd;
291
+ border-radius: 5px;
292
+ margin-bottom: 20px;
293
+ }
294
+
295
+ .network-config {
296
+ margin-bottom: 20px;
297
+ }
298
+
299
+ .network-config h4 {
300
+ margin: 0 0 15px 0;
301
+ font-size: 1.1em;
302
+ }
303
+
304
+ .block-config {
305
+ display: flex;
306
+ justify-content: space-between;
307
+ gap: 20px;
308
+ }
309
+
310
+ .block {
311
+ flex: 1;
312
+ }
313
+
314
+ .block label {
315
+ display: block;
316
+ margin-bottom: 5px;
317
+ font-weight: bold;
318
+ }
319
+
320
+ .training-config {
321
+ display: flex;
322
+ gap: 20px;
323
+ }
324
+
325
+ .config-item {
326
+ flex: 1;
327
+ }
328
+
329
+ .config-item label {
330
+ display: block;
331
+ margin-bottom: 5px;
332
+ font-weight: bold;
333
+ }
334
+
335
+ select {
336
+ width: 100%;
337
+ padding: 8px;
338
+ border: 1px solid #ddd;
339
+ border-radius: 4px;
340
+ }
341
+
342
+ .controls {
343
+ margin: 20px 0;
344
+ }
345
+
346
+ button {
347
+ padding: 10px 20px;
348
+ margin-right: 10px;
349
+ border: none;
350
+ border-radius: 4px;
351
+ background-color: #007bff;
352
+ color: white;
353
+ cursor: pointer;
354
+ }
355
+
356
+ button:disabled {
357
+ background-color: #ccc;
358
+ cursor: not-allowed;
359
+ }
360
+
361
+ .charts-container {
362
+ display: flex;
363
+ flex-direction: column;
364
+ gap: 20px;
365
+ margin-top: 20px;
366
+ }
367
+
368
+ #lossChart, #accuracyChart {
369
+ height: 400px;
370
+ width: 100%;
371
+ }
372
+ </style>
373
+ </body>
374
+ </html>