File size: 3,556 Bytes
4c1a791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// NOTE(review): `ws` is declared here but never assigned or read in the
// comparison code below; presumably a WebSocket handle used elsewhere in the
// page's scripts — confirm, or remove if unused.
let ws;

/**
 * Draw the two comparison charts (loss and accuracy) with Plotly. Each chart
 * starts with one empty scatter trace per model and shares the
 * "Iterations" x-axis title.
 */
function initializeComparisonCharts() {
    // A fresh trace object with its own empty x/y arrays.
    const emptyScatter = (traceName) => ({
        name: traceName,
        x: [],
        y: [],
        type: 'scatter'
    });

    // Layout shared by both charts; only the chart and y-axis titles differ.
    const layoutFor = (chartTitle, yAxisTitle) => ({
        title: chartTitle,
        xaxis: { title: 'Iterations' },
        yaxis: { title: yAxisTitle }
    });

    // [target element id, metric label used in trace names, chart title, y-axis title]
    const chartSpecs = [
        ['comparison-loss-plot', 'Loss', 'Loss Comparison', 'Loss'],
        ['comparison-accuracy-plot', 'Accuracy', 'Accuracy Comparison', 'Accuracy (%)']
    ];

    for (const [elementId, metric, chartTitle, yAxisTitle] of chartSpecs) {
        const traces = [`Model 1 ${metric}`, `Model 2 ${metric}`].map((label) => emptyScatter(label));
        Plotly.newPlot(elementId, traces, layoutFor(chartTitle, yAxisTitle));
    }
}

/**
 * Collect both models' hyper-parameters from the comparison form, POST them
 * to `/api/train_compare`, and render the results when training succeeds.
 *
 * Integer fields (kernels, batch size, epochs) are parsed in base 10. If any
 * of them is not a number the request is not sent and the user is asked to
 * fix the form; otherwise JSON.stringify would silently send `null` for it.
 *
 * @returns {Promise<void>} Resolves once the request has been handled; every
 *     outcome (success, server-side failure, network/parse error) is reported
 *     to the user via `alert`.
 */
async function compareModels() {
    // Read one model's settings from the inputs whose ids start with `prefix`
    // (e.g. "model1_kernel1", "model1_optimizer").
    const readModelConfig = (prefix) => {
        const fieldValue = (suffix) => document.getElementById(`${prefix}_${suffix}`).value;
        const intField = (suffix) => Number.parseInt(fieldValue(suffix), 10);
        return {
            kernels: [intField('kernel1'), intField('kernel2'), intField('kernel3')],
            optimizer: fieldValue('optimizer'),
            batch_size: intField('batch_size'),
            epochs: intField('epochs')
        };
    };

    const config = {
        model1: readModelConfig('model1'),
        model2: readModelConfig('model2')
    };

    // Reject empty/non-numeric inputs up front instead of sending nulls.
    const hasInvalidNumber = [config.model1, config.model2].some((model) =>
        [...model.kernels, model.batch_size, model.epochs].some(Number.isNaN)
    );
    if (hasInvalidNumber) {
        alert('Please enter a whole number for every kernel, batch size and epoch field.');
        return;
    }

    // Show comparison progress section
    document.getElementById('comparison-progress').classList.remove('hidden');
    initializeComparisonCharts();

    try {
        const response = await fetch('/api/train_compare', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify(config)
        });
        const data = await response.json();

        if (response.ok && data.status === 'success') {
            displayComparisonResults(data);
            alert('Model comparison completed successfully!');
        } else {
            // An HTTP error or a non-success status must not fail silently:
            // log the reply and tell the user.
            console.error('Model comparison failed:', response.status, data);
            alert('Model comparison failed. Please check console for details.');
        }
    } catch (error) {
        console.error('Error:', error);
        alert('Error during model comparison. Please check console for details.');
    }
}

/**
 * Render a summary of both trained models into the `comparison-logs` element.
 *
 * @param {object} data - Reply from `/api/train_compare`. Reads
 *     `model1_results` and `model2_results`, each with `history.train_loss`
 *     and `history.train_acc` (arrays; the last entry is shown) and
 *     `model_name`.
 */
function displayComparisonResults(data) {
    // Escape server-provided text before interpolating it into markup, so a
    // model name containing "<" or "&" is shown literally, not parsed as HTML.
    const escapeHtml = (text) => String(text)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;')
        .replace(/'/g, '&#39;');

    // Format the final entry of a metric series, or 'N/A' when the series is
    // missing or empty (calling `.toFixed` on a missing entry would throw).
    const formatFinal = (series, digits, suffix = '') => {
        const last = Array.isArray(series) && series.length > 0
            ? series[series.length - 1]
            : undefined;
        return typeof last === 'number' ? `${last.toFixed(digits)}${suffix}` : 'N/A';
    };

    // Markup for one model's summary card.
    const renderModel = (title, results) => {
        const history = results.history || {};
        return `
        <div class="comparison-model">
            <h4>${title}</h4>
            <p>Final Loss: ${formatFinal(history.train_loss, 4)}</p>
            <p>Final Accuracy: ${formatFinal(history.train_acc, 2, '%')}</p>
            <p>Model Name: ${escapeHtml(results.model_name)}</p>
        </div>`;
    };

    const logsDiv = document.getElementById('comparison-logs');
    logsDiv.innerHTML = renderModel('Model 1', data.model1_results)
        + renderModel('Model 2', data.model2_results);
}