Commit 74a86c6 · Parent(s): 9692a92

Improved plotting code

Files changed:
- README.md +4 -0
- src/plot_experiment_results.py +184 -0
README.md
CHANGED
@@ -40,3 +40,7 @@ print(f'The given PROTAC is: {"active" if active_protac else "inactive"}')
 
 > If you're coming from my [thesis repo](https://github.com/ribesstefano/Machine-Learning-for-Predicting-Targeted-Protein-Degradation), I just wanted to create a separate and "less generic" repo for fast prototyping new ideas.
 > Stefano.
+
+
+
+> Why haven't you trained on more (i.e., the whole) data? We did, and we might just need _way_ more data to get better results...
src/plot_experiment_results.py
ADDED
@@ -0,0 +1,184 @@
import os
import sys

# Make the repository root importable when running this script directly.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np


palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']
def plot_metrics(df, title):
    # Clean the data
    df = df.dropna(how='all', axis=1)

    # Convert all columns to numeric, setting errors='coerce' to handle non-numeric data
    df = df.apply(pd.to_numeric, errors='coerce')

    # Group by 'epoch' and aggregate by mean
    epoch_data = df.groupby('epoch').mean()

    fig, ax = plt.subplots(3, 1, figsize=(10, 15))

    # Plot training and test loss
    ax[0].plot(epoch_data.index, epoch_data['train_loss_epoch'], label='Training Loss')
    ax[0].plot(epoch_data.index, epoch_data['test_loss'], label='Test Loss', linestyle='--')
    ax[0].set_ylabel('Loss')
    ax[0].legend(loc='lower right')
    ax[0].grid(axis='both', alpha=0.5)

    # Plot training and test accuracy
    ax[1].plot(epoch_data.index, epoch_data['train_acc_epoch'], label='Training Accuracy')
    ax[1].plot(epoch_data.index, epoch_data['test_acc'], label='Test Accuracy', linestyle='--')
    ax[1].set_ylabel('Accuracy')
    ax[1].legend(loc='lower right')
    ax[1].grid(axis='both', alpha=0.5)

    # Plot training and test ROC-AUC
    ax[2].plot(epoch_data.index, epoch_data['train_roc_auc_epoch'], label='Training ROC-AUC')
    ax[2].plot(epoch_data.index, epoch_data['test_roc_auc'], label='Test ROC-AUC', linestyle='--')
    ax[2].set_ylabel('ROC-AUC')
    ax[2].legend(loc='lower right')
    ax[2].grid(axis='both', alpha=0.5)

    # Set the x-axis label on the bottom subplot
    ax[2].set_xlabel('Epoch')

    plt.title(title)
    plt.tight_layout()
    plt.savefig(f'plots/{title}_metrics.pdf', bbox_inches='tight')
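

# Usage sketch (hypothetical helper, never called): plot_metrics expects a
# long-format metrics CSV with an 'epoch' column plus the train_/test_ metric
# columns referenced above -- it appears to match the metrics.csv that a
# PyTorch Lightning CSVLogger writes, as in the commented-out block in main().
# The path below is an example, not a file tracked in this repo.
def _example_plot_metrics_usage():
    df = pd.read_csv('logs/some_run/metrics.csv')  # hypothetical log location
    plot_metrics(df, title='some_run')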
def plot_report(df_cv, df_test, title=None):
    # Extract and prepare CV data
    cv_data = df_cv[['model_type', 'fold', 'val_acc', 'val_roc_auc', 'test_acc', 'test_roc_auc', 'split_type']]
    cv_data = cv_data.melt(id_vars=['model_type', 'fold', 'split_type'], var_name='Metric', value_name='Score')
    cv_data['Metric'] = cv_data['Metric'].replace({
        'val_acc': 'Validation Accuracy',
        'val_roc_auc': 'Validation ROC AUC',
        'test_acc': 'Test Accuracy',
        'test_roc_auc': 'Test ROC AUC',
    })
    cv_data['Stage'] = cv_data['Metric'].apply(lambda x: 'Validation' if 'Val' in x else 'Test')

    # Extract and prepare test data
    test_data = df_test[['model_type', 'test_acc', 'test_roc_auc', 'split_type']]
    test_data = test_data.melt(id_vars=['model_type', 'split_type'], var_name='Metric', value_name='Score')
    test_data['Metric'] = test_data['Metric'].replace({
        'test_acc': 'Test Accuracy',
        'test_roc_auc': 'Test ROC AUC',
    })
    test_data['Stage'] = 'Test'

    # Combine CV and test data
    combined_data = pd.concat([cv_data, test_data], ignore_index=True)

    # Rename 'split_type' values according to a predefined map for clarity
    group2name = {
        'random': 'Standard Split',
        'uniprot': 'Target Split',
        'tanimoto': 'Similarity Split',
    }
    combined_data['Split Type'] = combined_data['split_type'].map(group2name)

    # Add dummy-model baselines: majority-class accuracy and 0.5 ROC AUC
    dummy_val_acc = []
    dummy_test_acc = []
    for i, group in enumerate(group2name.keys()):
        # Get the majority class in group_df
        group_df = df_cv[df_cv['split_type'] == group]
        major_col = 'inactive' if group_df['val_inactive_perc'].mean() > 0.5 else 'active'
        dummy_val_acc.append(group_df[f'val_{major_col}_perc'].mean())

        group_df = df_test[df_test['split_type'] == group]
        major_col = 'inactive' if group_df['test_inactive_perc'].mean() > 0.5 else 'active'
        dummy_test_acc.append(group_df[f'test_{major_col}_perc'].mean())

    dummy_scores = []
    metrics = ['Validation Accuracy', 'Validation ROC AUC', 'Test Accuracy', 'Test ROC AUC']
    for i in range(len(dummy_val_acc)):
        for metric, score in zip(metrics, [dummy_val_acc[i], 0.5, dummy_test_acc[i], 0.5]):
            dummy_scores.append({
                'Experiment': i,
                'Metric': metric,
                'Score': score,
                'Split Type': 'Dummy model',
            })
    dummy_model = pd.DataFrame(dummy_scores)
    combined_data = pd.concat([combined_data, dummy_model], ignore_index=True)

    # Plotting
    plt.figure(figsize=(12, 6))
    sns.barplot(data=combined_data, x='Metric', y='Score', hue='Split Type', errorbar='sd', palette=palette)
    plt.title('')
    plt.ylabel('')
    plt.xlabel('')
    plt.ylim(0, 1.0)  # Assuming scores are normalized between 0 and 1
    plt.grid(axis='y', alpha=0.5, linewidth=0.5)

    # Format the y-axis as percentages
    plt.gca().yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
    # Place the legend below the x-axis, outside the plot, split into four columns
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.08), ncol=4)

    # For each bar, add the rotated value (as a percentage) inside the bar
    for i, p in enumerate(plt.gca().patches):
        # TODO: For some reason, there are 4 additional rectangles being
        # plotted... I suspect it's because dummy_model doesn't have the same
        # shape as the frame containing all the evaluation data...
        if p.get_height() < 0.01:
            continue
        if i % 2 == 0:
            value = '{:.1f}%'.format(100 * p.get_height())
        else:
            value = '{:.2f}'.format(p.get_height())

        print(f'Plotting value: {p.get_height()} -> {value}')
        x = p.get_x() + p.get_width() / 2
        y = 0.4  # p.get_height() - p.get_height() / 2
        plt.annotate(value, (x, y), ha='center', va='center', color='black', fontsize=10, rotation=90, alpha=0.8)

    plt.savefig(f'plots/{title}.pdf', bbox_inches='tight')
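

# Note on the TODO in plot_report: with matplotlib >= 3.4, iterating over
# ax.containers and calling Axes.bar_label only visits the bars that were
# actually drawn, which would sidestep the stray rectangles in ax.patches.
# A sketch of that alternative (hypothetical helper, never called), assuming
# the same combined_data frame built in plot_report:
def _example_bar_label_alternative(combined_data):
    ax = sns.barplot(data=combined_data, x='Metric', y='Score', hue='Split Type', errorbar='sd', palette=palette)
    for container in ax.containers:
        # Each BarContainer holds only real bars, so no phantom patches appear.
        ax.bar_label(container, fmt='%.2f', label_type='center', rotation=90, fontsize=10)
    return ax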
def main():
    active_col = 'Active (Dmax 0.6, pDC50 6.0)'
    test_split = 0.1
    n_models_for_test = 3

    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
    report_base_name = f'{active_name}_test_split_{test_split}'

    # Ensure the output directory for the plots exists
    os.makedirs('plots', exist_ok=True)

    # Load the data
    reports = {
        'cv_train': pd.read_csv(f'reports/report_cv_train_{report_base_name}.csv'),
        'test': pd.read_csv(f'reports/report_test_{report_base_name}.csv'),
        'ablation': pd.read_csv(f'reports/report_ablation_{report_base_name}.csv'),
        'hparam': pd.read_csv(f'reports/report_hparam_{report_base_name}.csv'),
    }

    # metrics = {}
    # for i in range(n_models_for_test):
    #     for split_type in ['random', 'tanimoto', 'uniprot', 'e3_ligase']:
    #         logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
    #         metrics[f'{split_type}_{i}'] = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
    #         metrics[f'{split_type}_{i}']['model_id'] = i
    #         # Rename 'val_' columns to 'test_' columns
    #         metrics[f'{split_type}_{i}'] = metrics[f'{split_type}_{i}'].rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
    #
    #         plot_metrics(metrics[f'{split_type}_{i}'], f'{split_type}_{i}')

    df_val = reports['cv_train']
    df_test = reports['test']
    plot_report(df_val, df_test, title=f'{active_name}_metrics')


if __name__ == '__main__':
    main()
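For reference, the slug construction in main() works out as follows (a quick check of the string munging, using the constants from the code above):

    active_col = 'Active (Dmax 0.6, pDC50 6.0)'
    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
    print(active_name)                      # Active_Dmax_0.6_pDC50_6.0
    print(f'{active_name}_test_split_0.1')  # Active_Dmax_0.6_pDC50_6.0_test_split_0.1

so the script expects report files such as reports/report_cv_train_Active_Dmax_0.6_pDC50_6.0_test_split_0.1.csv to be present.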