Create GreaterThan_MLP_V1.1_with_FailuresAnalysis.ipynb
        GreaterThan_MLP_V1.1_with_FailuresAnalysis.ipynb
    ADDED
    
@@ -0,0 +1,467 @@
# GreaterThan_MLP_V1.1_with_FailuresAnalysis.py
"""
The objective of GreaterThan_MLP_V1.0.py was to establish a fundamental performance baseline
for a numerical comparison task using a deliberately simple Multi-Layer Perceptron (MLP).
It avoids all natural language processing techniques by treating the problem as pure binary
classification on a fixed-size vector. The dataset consists of synthetically generated pairs
of decimal numbers with two integer and two fractional digits (e.g., 10.00 and 09.21),
which are deconstructed and flattened into an 8-dimensional feature vector of their raw
digits ([1, 0, 0, 0, 0, 9, 2, 1]).
The model is then trained to predict a single binary label (0 for left > right, 1 for
right > left), directly testing the MLP's capability to learn the hierarchical rules of
numerical magnitude from the positional values of the input digits alone.

The design is deliberately minimal: a simple MLP for binary classification, with on-the-fly
data generation. The 8-dimensional input vector, built from the two four-digit numbers, is
the focus, and the output cleanly indicates which number is greater. The generate_mlp_data
function produces the 8-dimensional input vectors and binary labels. The script is written
for maximum clarity and serves as a fundamental baseline for this reasoning problem.

The MLP baseline performs remarkably well, achieving over 99.9% accuracy on the "GreaterThan"
decision. This indicates that the underlying logic of numerical comparison can be learned
from raw digits by a simple neural network, provided the input is structured as a fixed-size
vector. However, even at this accuracy, failures still occur. Understanding why, and on which
data, the model fails is the next critical step in ML engineering: it is how we discover
dataset biases, edge cases, and architectural weaknesses.

This modified script, GreaterThan_MLP_V1.1_with_FailuresAnalysis.py, incorporates logic to
automatically detect and log failures to a CSV file whenever accuracy is high, creating a
valuable dataset artifact for future analysis and for developing more robust models.

Here is the output of the first demonstration run in Colab:
Model initialized with 9473 parameters.

--- Starting Training ---
Epoch [1/100], Train Loss: 0.4015, Train Acc: 82.67%, | Val Loss: 0.1690, Val Acc: 97.03%
Epoch [2/100], Train Loss: 0.1743, Train Acc: 92.94%, | Val Loss: 0.0974, Val Acc: 98.04%
Epoch [3/100], Train Loss: 0.1300, Train Acc: 94.54%, | Val Loss: 0.0741, Val Acc: 98.61%
Epoch [4/100], Train Loss: 0.1112, Train Acc: 95.20%, | Val Loss: 0.0618, Val Acc: 98.96%
Epoch [5/100], Train Loss: 0.1019, Train Acc: 95.61%, | Val Loss: 0.0565, Val Acc: 98.79%
Epoch [6/100], Train Loss: 0.0926, Train Acc: 96.04%, | Val Loss: 0.0498, Val Acc: 99.10%
  -> High accuracy detected. Scanning for failures...
    -> Logged 607 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 180 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [7/100], Train Loss: 0.0857, Train Acc: 96.33%, | Val Loss: 0.0456, Val Acc: 99.19%
  -> High accuracy detected. Scanning for failures...
    -> Logged 562 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 161 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [8/100], Train Loss: 0.0827, Train Acc: 96.47%, | Val Loss: 0.0430, Val Acc: 99.14%
  -> High accuracy detected. Scanning for failures...
    -> Logged 538 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 171 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [9/100], Train Loss: 0.0767, Train Acc: 96.73%, | Val Loss: 0.0398, Val Acc: 99.33%
  -> High accuracy detected. Scanning for failures...
    -> Logged 462 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 133 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [10/100], Train Loss: 0.0727, Train Acc: 96.87%, | Val Loss: 0.0376, Val Acc: 99.33%
  -> High accuracy detected. Scanning for failures...
    -> Logged 457 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 134 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [11/100], Train Loss: 0.0692, Train Acc: 97.04%, | Val Loss: 0.0380, Val Acc: 99.06%
  -> High accuracy detected. Scanning for failures...
    -> Logged 703 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 189 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [12/100], Train Loss: 0.0665, Train Acc: 97.17%, | Val Loss: 0.0333, Val Acc: 99.42%
  -> High accuracy detected. Scanning for failures...
    -> Logged 365 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 117 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [13/100], Train Loss: 0.0619, Train Acc: 97.36%, | Val Loss: 0.0316, Val Acc: 99.42%
  -> High accuracy detected. Scanning for failures...
    -> Logged 396 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 115 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [14/100], Train Loss: 0.0599, Train Acc: 97.46%, | Val Loss: 0.0301, Val Acc: 99.41%
  -> High accuracy detected. Scanning for failures...
    -> Logged 397 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 119 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [15/100], Train Loss: 0.0568, Train Acc: 97.63%, | Val Loss: 0.0282, Val Acc: 99.47%
  -> High accuracy detected. Scanning for failures...
    -> Logged 359 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 107 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [16/100], Train Loss: 0.0550, Train Acc: 97.72%, | Val Loss: 0.0266, Val Acc: 99.53%
  -> High accuracy detected. Scanning for failures...
    -> Logged 331 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 94 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [17/100], Train Loss: 0.0524, Train Acc: 97.80%, | Val Loss: 0.0256, Val Acc: 99.55%
  -> High accuracy detected. Scanning for failures...
    -> Logged 321 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 91 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [18/100], Train Loss: 0.0504, Train Acc: 97.93%, | Val Loss: 0.0240, Val Acc: 99.56%
  -> High accuracy detected. Scanning for failures...
    -> Logged 290 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 87 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [19/100], Train Loss: 0.0472, Train Acc: 98.04%, | Val Loss: 0.0228, Val Acc: 99.53%
  -> High accuracy detected. Scanning for failures...
    -> Logged 288 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 93 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [20/100], Train Loss: 0.0447, Train Acc: 98.16%, | Val Loss: 0.0216, Val Acc: 99.61%
  -> High accuracy detected. Scanning for failures...
    -> Logged 289 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 78 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [21/100], Train Loss: 0.0445, Train Acc: 98.12%, | Val Loss: 0.0201, Val Acc: 99.69%
  -> High accuracy detected. Scanning for failures...
    -> Logged 240 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 63 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [22/100], Train Loss: 0.0412, Train Acc: 98.29%, | Val Loss: 0.0191, Val Acc: 99.65%
  -> High accuracy detected. Scanning for failures...
    -> Logged 227 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 70 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [23/100], Train Loss: 0.0395, Train Acc: 98.35%, | Val Loss: 0.0181, Val Acc: 99.65%
  -> High accuracy detected. Scanning for failures...
    -> Logged 236 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 70 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [24/100], Train Loss: 0.0373, Train Acc: 98.48%, | Val Loss: 0.0170, Val Acc: 99.71%
  -> High accuracy detected. Scanning for failures...
    -> Logged 209 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 58 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [25/100], Train Loss: 0.0362, Train Acc: 98.53%, | Val Loss: 0.0164, Val Acc: 99.68%
  -> High accuracy detected. Scanning for failures...
    -> Logged 222 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 64 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [26/100], Train Loss: 0.0345, Train Acc: 98.61%, | Val Loss: 0.0153, Val Acc: 99.73%
  -> High accuracy detected. Scanning for failures...
    -> Logged 199 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 53 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [27/100], Train Loss: 0.0317, Train Acc: 98.74%, | Val Loss: 0.0149, Val Acc: 99.61%
  -> High accuracy detected. Scanning for failures...
    -> Logged 253 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 78 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [28/100], Train Loss: 0.0302, Train Acc: 98.80%, | Val Loss: 0.0134, Val Acc: 99.80%
  -> High accuracy detected. Scanning for failures...
    -> Logged 162 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 40 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [29/100], Train Loss: 0.0299, Train Acc: 98.80%, | Val Loss: 0.0127, Val Acc: 99.77%
  -> High accuracy detected. Scanning for failures...
    -> Logged 163 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 46 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [30/100], Train Loss: 0.0261, Train Acc: 98.98%, | Val Loss: 0.0125, Val Acc: 99.68%
  -> High accuracy detected. Scanning for failures...
    -> Logged 240 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 64 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [31/100], Train Loss: 0.0251, Train Acc: 99.05%, | Val Loss: 0.0110, Val Acc: 99.84%
  -> High accuracy detected. Scanning for failures...
    -> Logged 135 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 32 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [32/100], Train Loss: 0.0246, Train Acc: 99.01%, | Val Loss: 0.0108, Val Acc: 99.78%
  -> High accuracy detected. Scanning for failures...
    -> Logged 167 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 43 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [33/100], Train Loss: 0.0237, Train Acc: 99.07%, | Val Loss: 0.0103, Val Acc: 99.83%
  -> High accuracy detected. Scanning for failures...
    -> Logged 121 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 34 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [34/100], Train Loss: 0.0224, Train Acc: 99.14%, | Val Loss: 0.0096, Val Acc: 99.86%
  -> High accuracy detected. Scanning for failures...
    -> Logged 127 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 29 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [35/100], Train Loss: 0.0220, Train Acc: 99.15%, | Val Loss: 0.0092, Val Acc: 99.89%
  -> High accuracy detected. Scanning for failures...
    -> Logged 100 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 23 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [36/100], Train Loss: 0.0204, Train Acc: 99.22%, | Val Loss: 0.0090, Val Acc: 99.83%
  -> High accuracy detected. Scanning for failures...
    -> Logged 126 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 34 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [37/100], Train Loss: 0.0194, Train Acc: 99.25%, | Val Loss: 0.0083, Val Acc: 99.89%
  -> High accuracy detected. Scanning for failures...
    -> Logged 93 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 23 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [38/100], Train Loss: 0.0191, Train Acc: 99.25%, | Val Loss: 0.0081, Val Acc: 99.85%
  -> High accuracy detected. Scanning for failures...
    -> Logged 110 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 30 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
Epoch [39/100], Train Loss: 0.0182, Train Acc: 99.31%, | Val Loss: 0.0076, Val Acc: 99.89%
  -> High accuracy detected. Scanning for failures...
    -> Logged 74 failures for 'train' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv
    -> Logged 22 failures for 'val' split to GreaterThan_MLP_V1.1_FailureAnalysis_failed_samples.csv

"""
# REFACTORING MISSION:
# This script's objective is to perform automated failure analysis. When
# training or validation accuracy surpasses a 99% threshold, the script
# automatically logs the specific samples that the model failed on. These
# failures are appended to a CSV file for later inspection, which is
# invaluable for creating targeted test sets or improving the training data.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import random
import numpy as np
import zipfile
import os
import sys

# ==============================================================================
# Part 0: Configuration
# ==============================================================================
class Config:
    # --- Data ---
    num_samples = 100000 # Increased dataset size for more robust training
    train_split = 0.8

    # --- Model Architecture ---
    input_size = 8
    hidden_size_1 = 128
    hidden_size_2 = 64
    output_size = 1

    # --- Training ---
    learning_rate = 1e-4 # Slightly lower LR for finer tuning
    batch_size = 256
    epochs = 100 # Generous upper bound; with the larger dataset the model converges far sooner
    weight_decay = 1e-4

    # --- NEW: Failure Analysis ---
    # The accuracy threshold to trigger logging of failed samples.
    failure_log_threshold = 99.0
    # The name of the script, used for the output CSV file.
    script_name = "GreaterThan_MLP_V1.1_FailureAnalysis"
    failure_log_filename = f"{script_name}_failed_samples.csv"

    # --- Device ---
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

config = Config()

# For reproducibility
torch.manual_seed(1337)
random.seed(1337)
np.random.seed(1337)

# ==============================================================================
# Part 1: Colab Utility & Data Generation (Unchanged from V1.0)
# ==============================================================================
def is_in_colab():
    """Checks if the script is running in a Google Colab environment."""
    try:
        import google.colab
        return True
    except ImportError:
        return False

def generate_mlp_data(num_samples):
    """Generates synthetic data for the MLP."""
    print(f"Generating {num_samples} data points...")
    features, labels = [], []
    for _ in range(num_samples):
        a = round(random.uniform(0, 99.99), 2)
        b = round(random.uniform(0, 99.99), 2)
        while a == b:
            b = round(random.uniform(0, 99.99), 2)
        a_str, b_str = f"{a:05.2f}", f"{b:05.2f}"
        a_digits, b_digits = [int(d) for d in a_str if d.isdigit()], [int(d) for d in b_str if d.isdigit()]
        features.append(a_digits + b_digits)
        labels.append(0 if a > b else 1)
    X = torch.tensor(features, dtype=torch.float32)
    y = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)
    print("Data generation complete.")
    return X, y

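# --- Illustrative sanity check (the helper name _demo_encoding is ours, not
# part of the original pipeline): verifies the encoding described in the
# docstring, i.e. that 10.00 vs 09.21 flattens to [1, 0, 0, 0, 0, 9, 2, 1]
# with label 0 (left > right), using the same formatting as generate_mlp_data.
def _demo_encoding():
    a, b = 10.00, 9.21
    a_str, b_str = f"{a:05.2f}", f"{b:05.2f}"  # '10.00', '09.21'
    digits = [int(d) for d in (a_str + b_str) if d.isdigit()]
    assert digits == [1, 0, 0, 0, 0, 9, 2, 1]
    assert (0 if a > b else 1) == 0  # left number is greater -> label 0
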
# ==============================================================================
# Part 2: Model Architecture (Unchanged from V1.0)
# ==============================================================================
class SimpleMLP(nn.Module):
    def __init__(self, input_size, hidden_size_1, hidden_size_2, output_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size_1), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_size_1, hidden_size_2), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_size_2, output_size)
        )
    def forward(self, x):
        return self.net(x)

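# Illustrative check (the helper below is not used by the training pipeline):
# the "9473 parameters" figure in the demonstration log follows from the
# architecture above: (8*128 + 128) + (128*64 + 64) + (64*1 + 1) = 9473.
def _count_parameters(model: nn.Module) -> int:
    """Counts trainable parameters; _count_parameters(SimpleMLP(8, 128, 64, 1)) == 9473."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
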
# ==============================================================================
# Part 3: MODIFIED Training and Evaluation Loop
# ==============================================================================

def log_failures(model, loader, split, epoch, filename, device):
    """
    NEW: Iterates through a data loader, finds incorrect predictions,
    and appends them to the specified CSV log file.
    """
    model.eval()
    failures_found = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = torch.round(torch.sigmoid(outputs))
            # Squeeze only the label dimension so a batch of size 1 still
            # yields a 1-D boolean mask rather than a 0-D scalar.
            mismatch_indices = (predicted != labels).squeeze(1)

            if mismatch_indices.any():
                failed_inputs = inputs[mismatch_indices]
                failed_true_labels = labels[mismatch_indices]
                failed_pred_labels = predicted[mismatch_indices]

                with open(filename, 'a') as f:
                    for i in range(failed_inputs.size(0)):
                        # Format the input vector back into a readable string
                        input_vec_int = failed_inputs[i].cpu().numpy().astype(int)
                        num1_str = f"{input_vec_int[0]}{input_vec_int[1]}.{input_vec_int[2]}{input_vec_int[3]}"
                        num2_str = f"{input_vec_int[4]}{input_vec_int[5]}.{input_vec_int[6]}{input_vec_int[7]}"

                        true_label = int(failed_true_labels[i].item())
                        pred_label = int(failed_pred_labels[i].item())

                        f.write(f"{epoch},{split},{num1_str},{num2_str},{true_label},{pred_label}\n")
                        failures_found += 1
    if failures_found > 0:
        print(f"    -> Logged {failures_found} failures for '{split}' split to {filename}")

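# Illustrative sketch of downstream analysis (assumes pandas is installed; the
# helper name and grouping are ours, not part of the training pipeline). The
# CSV written by log_failures can be mined for patterns, e.g. whether failures
# cluster on pairs whose integer parts are equal, so the decision hinges
# entirely on the fractional digits.
def _summarize_failures(filename):
    import pandas as pd
    df = pd.read_csv(filename)
    # Flag pairs like 25.80 vs 25.79 where only the decimals differ
    df["same_int_part"] = (df["num1"].astype(str).str.split(".").str[0]
                           == df["num2"].astype(str).str.split(".").str[0])
    # Failure counts per epoch/split, and share of "same integer part" cases
    return df.groupby(["epoch", "split"]).agg(
        failures=("same_int_part", "size"),
        same_int_share=("same_int_part", "mean"),
    )
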
def train_model(model, train_loader, val_loader, optimizer, criterion, epochs, config):
    """The main training loop, now with failure logging."""
    print("\n--- Starting Training ---")

    # NEW: Initialize the failure log file with a header
    with open(config.failure_log_filename, 'w') as f:
        f.write("epoch,split,num1,num2,true_label(0:L>R;1:R>L),predicted_label\n")

    for epoch in range(epochs):
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(config.device), labels.to(config.device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            predicted = torch.round(torch.sigmoid(outputs))
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_avg_loss = train_loss / len(train_loader)
        train_accuracy = 100 * train_correct / train_total

        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(config.device), labels.to(config.device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()
                predicted = torch.round(torch.sigmoid(outputs))
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_avg_loss = val_loss / len(val_loader)
        val_accuracy = 100 * val_correct / val_total

        print(f"Epoch [{epoch+1}/{epochs}], "
              f"Train Loss: {train_avg_loss:.4f}, Train Acc: {train_accuracy:.2f}%, | "
              f"Val Loss: {val_avg_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

        # --- NEW: Conditional Failure Logging ---
        if train_accuracy > config.failure_log_threshold or val_accuracy > config.failure_log_threshold:
            print(f"  -> High accuracy detected. Scanning for failures...")
            # Re-iterate over loaders to find and log the specific failures for this epoch
            log_failures(model, train_loader, 'train', epoch + 1, config.failure_log_filename, config.device)
            log_failures(model, val_loader, 'val', epoch + 1, config.failure_log_filename, config.device)

    print("--- Training Finished ---")

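# Illustrative wiring sketch only: the script's own Main Execution Block
# (below) is the authoritative version. The choice of Adam and
# BCEWithLogitsLoss here is an assumption, consistent with the sigmoid
# applied to the raw logits in the loops above.
def _example_run():
    X, y = generate_mlp_data(config.num_samples)
    n_train = int(config.train_split * len(X))
    train_loader = DataLoader(TensorDataset(X[:n_train], y[:n_train]),
                              batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(TensorDataset(X[n_train:], y[n_train:]),
                            batch_size=config.batch_size)
    model = SimpleMLP(config.input_size, config.hidden_size_1,
                      config.hidden_size_2, config.output_size).to(config.device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate,
                                 weight_decay=config.weight_decay)
    train_model(model, train_loader, val_loader, optimizer, criterion,
                config.epochs, config)
    run_final_tests(model, config)
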
| 373 | 
         
            +
            # ==============================================================================
         
     | 
| 374 | 
         
            +
            # Part 4: MODIFIED Final Test Suite
         
     | 
| 375 | 
         
            +
            # ==============================================================================
         
     | 
| 376 | 
         
            +
            def run_final_tests(model, config):
         
     | 
| 377 | 
         
            +
                """Runs the trained model against a suite of hardcoded test cases and logs failures."""
         
     | 
| 378 | 
         
            +
                print("\n--- Running Final Test Suite ---")
         
     | 
| 379 | 
         
            +
                
         
     | 
| 380 | 
         
            +
                test_cases = [
         
     | 
| 381 | 
         
            +
                    ("Simple Greater", 10.00, 9.21), ("Simple Lesser", 5.50, 50.50),
         
     | 
| 382 | 
         
            +
                    ("Decimal Greater", 54.13, 54.12), ("Decimal Lesser", 99.98, 99.99),
         
     | 
| 383 | 
         
            +
                    ("Edge Case: Large Difference", 0.01, 99.99), ("Edge Case: Zero", 0.00, 5.00),
         
     | 
| 384 | 
         
            +
                    ("Tricky: Same Integer Part", 25.80, 25.79), ("Tricky: Crossover", 49.99, 50.00),
         
     | 
| 385 | 
         
            +
                ]
         
     | 
| 386 | 
         
            +
                
         
     | 
| 387 | 
         
            +
                results_log = "--- MLP Test Suite Results ---\n\n"
         
     | 
| 388 | 
         
            +
                correct_tests = 0
         
     | 
| 389 | 
         
            +
                
         
     | 
| 390 | 
         
            +
                model.eval()
         
     | 
| 391 | 
         
            +
                with torch.no_grad():
         
     | 
| 392 | 
         
            +
                    for description, a, b in test_cases:
         
     | 
| 393 | 
         
            +
                        a_str, b_str = f"{a:05.2f}", f"{b:05.2f}"
         
     | 
| 394 | 
         
            +
                        a_digits, b_digits = [int(d) for d in a_str if d.isdigit()], [int(d) for d in b_str if d.isdigit()]
         
     | 
| 395 | 
         
            +
                        feature_vector = torch.tensor(a_digits + b_digits, dtype=torch.float32).to(config.device)
         
     | 
| 396 | 
         
            +
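            # Encoding note: f"{a:05.2f}" zero-pads to a fixed "DD.DD" layout, so each
            # number contributes exactly four digits (e.g. 5.50 -> "05.50" -> [0, 5, 5, 0])
            # and the concatenated feature vector always has eight entries, matching the
            # assumed config.input_size of 8.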

            output = model(feature_vector)
            predicted_class = 1 if torch.sigmoid(output).item() > 0.5 else 0
            ground_truth_class = 0 if a > b else 1
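            # Label convention (from the line above): class 0 means the left operand is
            # greater, class 1 the right. The model emits a raw logit, so sigmoid > 0.5
            # maps to class 1.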

            result = "CORRECT"
            if predicted_class != ground_truth_class:
                result = "INCORRECT"
                # --- NEW: Log failure to CSV ---
                with open(config.failure_log_filename, 'a') as f:
                    f.write(f"final_test,{description.replace(',',';')},{a_str},{b_str},{ground_truth_class},{predicted_class}\n")
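                # Commas inside the human-readable description are swapped for semicolons
                # so each failure row stays a well-formed 6-column CSV record.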
            else:
                correct_tests += 1

            predicted_winner = "Left" if predicted_class == 0 else "Right"
            log_line = (f"Test: '{description}' | {a_str} vs {b_str}\n"
                        f"  -> Model says: {predicted_winner} is greater\n"
                        f"  -> Result: {result}\n" + "-"*30 + "\n")
            print(log_line)
            results_log += log_line

    final_accuracy = 100 * correct_tests / len(test_cases)
    summary = f"\nFinal Test Accuracy: {final_accuracy:.2f}% ({correct_tests}/{len(test_cases)} correct)\n"
    print(summary)
    results_log += summary

    return results_log

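# Editorial sketch: the failure log written above lends itself to a quick post-hoc
# analysis. This hypothetical helper (not part of the original notebook) assumes the
# 6-column row format used by run_final_tests; the epoch-level rows from log_failures
# may differ, so treat the column names as an assumption.
def _summarize_failures_sketch(filename):
    import pandas as pd
    cols = ["source", "description", "a", "b", "ground_truth", "predicted"]
    df = pd.read_csv(filename, names=cols)
    # Count failures per source (e.g. 'train', 'val', 'final_test')
    return df.groupby("source").size()
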
# ==============================================================================
# Main Execution Block
# ==============================================================================
if __name__ == '__main__':
    X, y = generate_mlp_data(config.num_samples)

    dataset = TensorDataset(X, y)
    train_size = int(config.train_split * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
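    # Note: random_split without an explicit generator gives a different train/val
    # partition on every run; passing e.g. generator=torch.Generator().manual_seed(42)
    # would make the split reproducible.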

    model = SimpleMLP(config.input_size, config.hidden_size_1, config.hidden_size_2, config.output_size).to(config.device)
    print(f"\nModel initialized with {sum(p.numel() for p in model.parameters())} parameters.")

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    criterion = nn.BCEWithLogitsLoss()
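    # BCEWithLogitsLoss folds the sigmoid into the loss for numerical stability, which
    # is why the model returns raw logits and the test suite applies torch.sigmoid only
    # at inference time.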

    train_model(model, train_loader, val_loader, optimizer, criterion, config.epochs, config)

    test_results = run_final_tests(model, config)

    if is_in_colab():
        print("\nDetected Google Colab environment. Zipping and downloading results...")
        results_filename = f"{config.script_name}_test_summary.txt"
        with open(results_filename, "w") as f:
            f.write(test_results)

        zip_filename = f"{config.script_name}_outputs.zip"
        with zipfile.ZipFile(zip_filename, 'w') as zipf:
            zipf.write(results_filename)
            # Also include the new failure log in the zip file
            if os.path.exists(config.failure_log_filename):
                zipf.write(config.failure_log_filename)

        try:
            from google.colab import files
            files.download(zip_filename)
            print(f"Downloaded {zip_filename} successfully.")
        except Exception as e:
            print(f"Could not initiate download. Error: {e}")
    else:
        print(f"\nNot running in Colab. Test results printed above. Failures logged to {config.failure_log_filename}")