ribesstefano committed
Commit 74a86c6 · 1 Parent(s): 9692a92

Improved plotting code
Files changed (2):
  1. README.md +4 -0
  2. src/plot_experiment_results.py +184 -0
README.md CHANGED
@@ -40,3 +40,7 @@ print(f'The given PROTAC is: {"active" if active_protac else "inactive"}')
 
 > If you're coming from my [thesis repo](https://github.com/ribesstefano/Machine-Learning-for-Predicting-Targeted-Protein-Degradation), I just wanted to create a separate and "less generic" repo for fast prototyping new ideas.
 > Stefano.
+
+
+
+ > Why haven't you trained on more (i.e., the whole) data? We did, and we might just need _way_ more data to get better results...
src/plot_experiment_results.py ADDED
@@ -0,0 +1,184 @@
+ import os
+ import sys
+
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import pandas as pd
+ import numpy as np
+
+
+ palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']
+
+
+ def plot_metrics(df, title):
+     # Clean the data
+     df = df.dropna(how='all', axis=1)
+
+     # Convert all columns to numeric, with errors='coerce' to handle non-numeric data
+     df = df.apply(pd.to_numeric, errors='coerce')
+
+     # Group by 'epoch' and aggregate by mean
+     epoch_data = df.groupby('epoch').mean()
+
+     fig, ax = plt.subplots(3, 1, figsize=(10, 15))
+
+     # Plot training and test loss
+     ax[0].plot(epoch_data.index, epoch_data['train_loss_epoch'], label='Training Loss')
+     ax[0].plot(epoch_data.index, epoch_data['test_loss'], label='Test Loss', linestyle='--')
+     ax[0].set_ylabel('Loss')
+     ax[0].legend(loc='lower right')
+     ax[0].grid(axis='both', alpha=0.5)
+
+     # Plot training and test accuracy
+     ax[1].plot(epoch_data.index, epoch_data['train_acc_epoch'], label='Training Accuracy')
+     ax[1].plot(epoch_data.index, epoch_data['test_acc'], label='Test Accuracy', linestyle='--')
+     ax[1].set_ylabel('Accuracy')
+     ax[1].legend(loc='lower right')
+     ax[1].grid(axis='both', alpha=0.5)
+
+     # Plot training and test ROC-AUC
+     ax[2].plot(epoch_data.index, epoch_data['train_roc_auc_epoch'], label='Training ROC-AUC')
+     ax[2].plot(epoch_data.index, epoch_data['test_roc_auc'], label='Test ROC-AUC', linestyle='--')
+     ax[2].set_ylabel('ROC-AUC')
+     ax[2].legend(loc='lower right')
+     ax[2].grid(axis='both', alpha=0.5)
+
+     # Set x-axis label on the bottom subplot
+     ax[2].set_xlabel('Epoch')
+
+     fig.suptitle(title)  # plt.title() would only title the last subplot
+     plt.tight_layout()
+     os.makedirs('plots', exist_ok=True)  # Make sure the output directory exists
+     plt.savefig(f'plots/{title}_metrics.pdf', bbox_inches='tight')
+
+
+ def plot_report(df_cv, df_test, title=None):
+     # Extract and prepare CV data
+     cv_data = df_cv[['model_type', 'fold', 'val_acc', 'val_roc_auc', 'test_acc', 'test_roc_auc', 'split_type']]
+     cv_data = cv_data.melt(id_vars=['model_type', 'fold', 'split_type'], var_name='Metric', value_name='Score')
+     cv_data['Metric'] = cv_data['Metric'].replace({
+         'val_acc': 'Validation Accuracy',
+         'val_roc_auc': 'Validation ROC AUC',
+         'test_acc': 'Test Accuracy',
+         'test_roc_auc': 'Test ROC AUC',
+     })
+     cv_data['Stage'] = cv_data['Metric'].apply(lambda x: 'Validation' if 'Val' in x else 'Test')
+
+     # Extract and prepare test data
+     test_data = df_test[['model_type', 'test_acc', 'test_roc_auc', 'split_type']]
+     test_data = test_data.melt(id_vars=['model_type', 'split_type'], var_name='Metric', value_name='Score')
+     test_data['Metric'] = test_data['Metric'].replace({
+         'test_acc': 'Test Accuracy',
+         'test_roc_auc': 'Test ROC AUC',
+     })
+     test_data['Stage'] = 'Test'
+
+     # Combine CV and test data
+     combined_data = pd.concat([cv_data, test_data], ignore_index=True)
+
+     # Rename 'split_type' values according to a predefined map for clarity
+     group2name = {
+         'random': 'Standard Split',
+         'uniprot': 'Target Split',
+         'tanimoto': 'Similarity Split',
+     }
+     combined_data['Split Type'] = combined_data['split_type'].map(group2name)
+
+     # Add a dummy-model baseline that always predicts the majority class:
+     # its accuracy is the majority-class fraction, and its ROC AUC is 0.5
+     dummy_val_acc = []
+     dummy_test_acc = []
+     for group in group2name:
+         # Get the majority class in group_df
+         group_df = df_cv[df_cv['split_type'] == group]
+         major_col = 'inactive' if group_df['val_inactive_perc'].mean() > 0.5 else 'active'
+         dummy_val_acc.append(group_df[f'val_{major_col}_perc'].mean())
+
+         group_df = df_test[df_test['split_type'] == group]
+         major_col = 'inactive' if group_df['test_inactive_perc'].mean() > 0.5 else 'active'
+         dummy_test_acc.append(group_df[f'test_{major_col}_perc'].mean())
+
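+     # E.g., a split whose validation set is 70% inactive gives the baseline
+     # 0.70 validation accuracy; its ROC AUC is 0.5 by construction.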
+     dummy_scores = []
+     metrics = ['Validation Accuracy', 'Validation ROC AUC', 'Test Accuracy', 'Test ROC AUC']
+     for i in range(len(dummy_val_acc)):
+         for metric, score in zip(metrics, [dummy_val_acc[i], 0.5, dummy_test_acc[i], 0.5]):
+             dummy_scores.append({
+                 'Experiment': i,
+                 'Metric': metric,
+                 'Score': score,
+                 'Split Type': 'Dummy model',
+             })
+     dummy_model = pd.DataFrame(dummy_scores)
+     combined_data = pd.concat([combined_data, dummy_model], ignore_index=True)
+
+     # Plotting
+     plt.figure(figsize=(12, 6))
+     sns.barplot(data=combined_data, x='Metric', y='Score', hue='Split Type', errorbar='sd', palette=palette)
+     plt.title('')
+     plt.ylabel('')
+     plt.xlabel('')
+     plt.ylim(0, 1.0)  # Scores are normalized between 0 and 1
+     plt.grid(axis='y', alpha=0.5, linewidth=0.5)
+
+     # Format the y-axis as percentages
+     plt.gca().yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
+     # Place the legend below the x-axis, outside the plot, in four columns
+     plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.08), ncol=4)
+
+     # For each bar, add the rotated value (as a percentage) inside the bar
+     for i, p in enumerate(plt.gca().patches):
+         # TODO: For some reason, there are 4 additional rectangles being
+         # plotted... I suspect it's because dummy_model doesn't have the same
+         # shape as the df containing all the evaluation data...
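+         # (A possible cause, untested: some seaborn versions draw zero-height
+         # proxy bars for the legend entries, and those also end up in
+         # ax.patches; the height check below would skip exactly those.)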
+         if p.get_height() < 0.01:
+             continue
+         if i % 2 == 0:
+             value = '{:.1f}%'.format(100 * p.get_height())
+         else:
+             value = '{:.2f}'.format(p.get_height())
+
+         print(f'Plotting value: {p.get_height()} -> {value}')
+         x = p.get_x() + p.get_width() / 2
+         y = 0.4  # Fixed height so the labels sit inside the bars
+         plt.annotate(value, (x, y), ha='center', va='center', color='black', fontsize=10, rotation=90, alpha=0.8)
+
+     os.makedirs('plots', exist_ok=True)  # Make sure the output directory exists
+     plt.savefig(f'plots/{title}.pdf', bbox_inches='tight')
+
+
+ def main():
+     active_col = 'Active (Dmax 0.6, pDC50 6.0)'
+     test_split = 0.1
+     n_models_for_test = 3
+
+     active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
+     report_base_name = f'{active_name}_test_split_{test_split}'
+
+     # Load the data
+     reports = {
+         'cv_train': pd.read_csv(f'reports/report_cv_train_{report_base_name}.csv'),
+         'test': pd.read_csv(f'reports/report_test_{report_base_name}.csv'),
+         'ablation': pd.read_csv(f'reports/report_ablation_{report_base_name}.csv'),
+         'hparam': pd.read_csv(f'reports/report_hparam_{report_base_name}.csv'),
+     }
+
+     # metrics = {}
+     # for i in range(n_models_for_test):
+     #     for split_type in ['random', 'tanimoto', 'uniprot', 'e3_ligase']:
+     #         logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
+     #         metrics[f'{split_type}_{i}'] = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
+     #         metrics[f'{split_type}_{i}']['model_id'] = i
+     #         # Rename 'val_' columns to 'test_' columns
+     #         metrics[f'{split_type}_{i}'] = metrics[f'{split_type}_{i}'].rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
+     #
+     #         plot_metrics(metrics[f'{split_type}_{i}'], f'{split_type}_{i}')
+
+     df_val = reports['cv_train']
+     df_test = reports['test']
+     plot_report(df_val, df_test, title=f'{active_name}_metrics')
+
+
+ if __name__ == '__main__':
+     main()
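
For reference, a minimal sketch of how plot_report can be exercised on its own. The DataFrames, their values, and the import path below are illustrative assumptions rather than part of the commit; the real inputs are the reports/*.csv files loaded in main():

import pandas as pd

from src.plot_experiment_results import plot_report  # assumes running from the repo root

# Toy CV report: one row per fold/split. The *_perc columns are the class
# fractions that plot_report uses to build the dummy baseline.
df_cv = pd.DataFrame({
    'model_type': ['mlp'] * 3,
    'fold': [0, 1, 2],
    'split_type': ['random', 'uniprot', 'tanimoto'],
    'val_acc': [0.80, 0.72, 0.68],
    'val_roc_auc': [0.85, 0.75, 0.70],
    'test_acc': [0.78, 0.70, 0.66],
    'test_roc_auc': [0.83, 0.73, 0.69],
    'val_inactive_perc': [0.6] * 3,
    'val_active_perc': [0.4] * 3,
})

# Toy held-out test report, one row per split
df_test = pd.DataFrame({
    'model_type': ['mlp'] * 3,
    'split_type': ['random', 'uniprot', 'tanimoto'],
    'test_acc': [0.79, 0.71, 0.67],
    'test_roc_auc': [0.84, 0.74, 0.70],
    'test_inactive_perc': [0.6] * 3,
    'test_active_perc': [0.4] * 3,
})

plot_report(df_cv, df_test, title='toy_metrics')  # writes plots/toy_metrics.pdf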