wi-lab committed
Commit 7538df1 · verified · 1 Parent(s): 9149b8d

Update train.py

Files changed (1)
  1. train.py +445 -445
train.py CHANGED
@@ -1,446 +1,446 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 20 09:32:12 2024

This script contains the LWM pre-training and task-specific fine-tuning functions.

@author: Sadjad Alikhani
"""
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import csv
from utils import count_parameters
import time
#%% LOSS FUNCTION
def nmse_loss(y_pred, y_true):
    y_pred_flat = y_pred.view(y_pred.size(0), -1)
    y_true_flat = y_true.view(y_true.size(0), -1)
    mse = torch.sum((y_true_flat - y_pred_flat)**2, dim=-1)
    normalization = torch.sum(y_true_flat**2, dim=-1)
    return mse / normalization
#%%
def train_lwm(model, train_loaders, val_loaders, optimizer, scheduler, epochs, device, save_dir="models", log_file="training_log.csv"):

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Initialize CSV log
    if not os.path.exists(log_file):
        with open(log_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Train NMSE", "Validation NMSE", "Learning Rate", "Best Model"])

    train_nmse_losses = []
    val_nmse_losses = []
    best_val_nmse = float('inf')

    for epoch in range(epochs):
        model.train()
        train_nmse = 0.0
        train_samples = 0

        # Training loop across all buckets
        print(f"\nEpoch {epoch + 1}/{epochs} [Training]")
        for length, train_loader in train_loaders.items():
            print(f"Processing sequences of length {length}")
            with tqdm(train_loader, desc=f"Length {length} [Training]", unit="batch") as t:
                for batch in t:
                    # train_batches += 1
                    optimizer.zero_grad()

                    # Move data to device
                    input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]

                    # Forward pass
                    logits_lm, _, _ = model(input_ids, masked_pos)

                    # Compute NMSE
                    loss = torch.sum(nmse_loss(masked_tokens, logits_lm))
                    loss.backward()
                    optimizer.step()
                    scheduler.step()

                    train_nmse += loss.item()
                    train_samples += input_ids.shape[0]

                    # Update progress bar
                    t.set_postfix({"nmse": train_nmse/train_samples, "lr": scheduler.get_last_lr()[0]})

        # Average NMSE across training batches
        train_nmse /= max(train_samples, 1)
        train_nmse_losses.append(train_nmse)

        if epoch % 2 == 0:
            # Validation loop across all buckets
            model.eval()
            val_nmse = 0.0
            val_samples = 0
            with torch.no_grad():
                print(f"\nEpoch {epoch + 1}/{epochs} [Validation]")
                for length, val_loader in val_loaders.items():
                    print(f"Processing sequences of length {length}")
                    with tqdm(val_loader, desc=f"Length {length} [Validation]", unit="batch") as t:
                        for batch in t:
                            # val_batches += 1

                            # Move data to device
                            input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]

                            # Forward pass
                            logits_lm, _, _ = model(input_ids, masked_pos)

                            # Compute NMSE
                            loss = torch.sum(nmse_loss(masked_tokens, logits_lm))
                            val_nmse += loss.item()
                            val_samples += input_ids.shape[0]

                            # Update progress bar
                            t.set_postfix({"nmse": val_nmse/val_samples})

            # Average NMSE across validation batches
            val_nmse /= max(val_samples, 1)
            val_nmse_losses.append(val_nmse)

            # Save model if validation NMSE improves
            is_best_model = False
            if val_nmse < best_val_nmse:
                best_val_nmse = val_nmse
                model_path = os.path.join(save_dir, f"lwm_epoch{epoch+1}_train{train_nmse:.4f}_val{val_nmse:.4f}.pth")
                torch.save(model.state_dict(), model_path)
                print(f"Model saved: {model_path}")
                is_best_model = True

            # Log the results
            print(f" Train NMSE: {train_nmse:.4f}")
            print(f" Validation NMSE: {val_nmse:.4f}")
            print(f" Learning Rate: {scheduler.get_last_lr()[0]:.6e}")

            # Append to CSV log
            with open(log_file, mode='a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([epoch + 1, train_nmse, val_nmse, scheduler.get_last_lr()[0], is_best_model])

        # Plot losses after each epoch
        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(train_nmse_losses) + 1), train_nmse_losses, label="Train NMSE")
        plt.plot(range(1, len(val_nmse_losses) + 1), val_nmse_losses, label="Validation NMSE")
        plt.xlabel("Epochs")
        plt.ylabel("NMSE")
        plt.title("Training and Validation NMSE Loss")
        plt.legend()
        plt.grid(True)
        plt.show()

    print("Training and validation complete.")
    return model
#%% FINE-TUNE
from torch.cuda.amp import GradScaler, autocast

# Define the ClassificationHead
class ClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)


# Define the RegressionHead
class RegressionHead(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        return self.fc(x)

class CustomClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes):

        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            # nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        return self.classifier(x)

class CustomRegressionHead(nn.Module):
    def __init__(self, input_dim, output_dim):

        super().__init__()
        self.regressor = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, output_dim)
        )

    def forward(self, x):
        return self.regressor(x)


def custom_heads(input_dim, num_classes=None, output_dim=None, task_type="classification"):
    """
    Creates a custom head for classification or regression tasks.
    Users should modify the class implementations for further customization.

    Args:
        input_dim (int): Input dimension of the head.
        num_classes (int): Number of classes for classification tasks. Ignored for regression.
        task_type (str): "classification" or "regression".

    Returns:
        nn.Module: Custom head for the specified task.
    """
    if task_type == "classification":
        if num_classes is None:
            raise ValueError("num_classes must be specified for classification tasks.")
        return CustomClassificationHead(input_dim=input_dim, num_classes=num_classes)
    elif task_type == "regression":
        return CustomRegressionHead(input_dim=input_dim, output_dim=output_dim)
    else:
        raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")
#%%
# Fine-tuning wrapper for the base model
class FineTuningWrapper(nn.Module):
    def __init__(self, model, task_head, fine_tune_layers="full"):
        super().__init__()
        self.model = model
        self.task_head = task_head

        # Freeze all layers initially
        for param in self.model.parameters():
            param.requires_grad = False

        # Handle fine-tuning layers
        if fine_tune_layers is not None:
            if fine_tune_layers == "full":
                # Unfreeze all layers if "all" is specified
                for param in self.model.parameters():
                    param.requires_grad = True
            else:
                # Get a list of all available layer names in the model
                available_layers = [name for name, _ in self.model.named_parameters()]

                # Validate that specified layers exist in the model
                for layer in fine_tune_layers:
                    if not any(layer in lname for lname in available_layers):
                        raise ValueError(
                            f"Layer '{layer}' not found in the model. "
                            f"Available layers: {available_layers}"
                        )

                # Unfreeze only the specified layers
                for name, param in self.model.named_parameters():
                    if any(layer in name for layer in fine_tune_layers):
                        param.requires_grad = True

    def forward(self, x, input_type="cls_emb"):
        if input_type == "raw":
            task_input = x.view(x.size(0), -1)
        else:
            embeddings, attn_maps = self.model(x) # Get embeddings from the base model
            if input_type == "cls_emb":
                task_input = embeddings[:, 0, :] # CLS token
-            elif input_type == "chs_emb":
+            elif input_type == "channel_emb":
                chs_emb = embeddings[:, 1:, :]
                task_input = chs_emb.view(chs_emb.size(0), -1) # embeddings.mean(dim=1) # Mean pooling over channel embeddings

        return self.task_head(task_input), 0 if input_type=="raw" else attn_maps
#%%
# Fine-tuning function
from sklearn.metrics import f1_score
def finetune(
    base_model,
    train_loader,
    val_loader=None,
    task_type="classification",
    input_type="cls_emb",
    num_classes=None,
    output_dim=None,
    use_custom_head=False,
    fine_tune_layers=None,
    optimizer_config=None,
    criterion=None,
    epochs=10,
    device="cuda",
    task="Beam Prediction"
):
    """
    Configures and fine-tunes the base model with user-defined settings, saving results and models.
    """
    # Create results folder
    time_now = f"{time.time():.0f}"
    results_folder = f"results/{task}/{time_now}"
    os.makedirs(results_folder, exist_ok=True)
    log_file = os.path.join(results_folder, "training_log.csv")

    # Initialize the CSV log
    with open(log_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Task", "Input", "Epoch", "Train Loss", "Validation Loss", "F1-Score (Classification)", "Learning Rate", "Time"])

    for batch in val_loader:
        input_data, targets = batch[0].to(device), batch[1].to(device)
        break

    if input_type == "cls_emb":
        n_patches = 1
        patch_size = 128
    elif input_type == "channel_emb":
        n_patches = input_data.shape[1]-1
        patch_size = 128
    elif input_type == "raw":
        n_patches = input_data.shape[1]
        patch_size = 32
        # patch_size = 1

    if use_custom_head:
        custom_head = custom_heads(input_dim=n_patches*patch_size,
                                   num_classes=num_classes,
                                   output_dim=output_dim,
                                   task_type=task_type)

    # Handle DataParallel models
    if isinstance(base_model, nn.DataParallel):
        base_model = base_model.module

    # Set up the task-specific head
    if use_custom_head:
        task_head = custom_head
    elif task_type == "classification":
        if num_classes is None:
            raise ValueError("num_classes must be specified for classification tasks.")
        task_head = ClassificationHead(input_dim=n_patches*patch_size, num_classes=num_classes) # input_dim=base_model.embedding.d_model
    elif task_type == "regression":
        task_head = RegressionHead(input_dim=n_patches*patch_size) # input_dim=base_model.embedding.d_model
    else:
        raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")

    # Wrap the model with the fine-tuning head
    wrapper = FineTuningWrapper(base_model, task_head, fine_tune_layers=fine_tune_layers)
    wrapper = wrapper.to(device)

    print(f'Number of head parameters: {count_parameters(wrapper)}')

    # Set default optimizer config if not provided
    if optimizer_config is None:
        optimizer_config = {"lr": 1e-4}
    # Set up the optimizer
    optimizer = torch.optim.Adam(wrapper.parameters(), **optimizer_config)
    # Set up the scheduler for learning rate decay
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2) # Example: Reduce LR by 10x every 10 epochs

    # Set up the loss criterion
    if criterion is None:
        criterion = nn.CrossEntropyLoss() if task_type == "classification" else nn.MSELoss()

    scaler = GradScaler()
    train_losses, val_losses, f1_scores = [], [], []
    best_val_loss = float("inf")
    best_model_path = None

    for epoch in range(epochs):
        # Training loop
        wrapper.train()
        epoch_loss = 0.0

        with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}") as progress_bar:
            for batch in progress_bar:
                input_data, targets = batch[0].to(device), batch[1].to(device)
                optimizer.zero_grad()

                with autocast():
                    outputs, attn_maps = wrapper(input_data, input_type=input_type)
                    loss = criterion(outputs, targets)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                epoch_loss += loss.item()
                progress_bar.set_postfix({"Loss": loss.item()})

        avg_train_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation loop
        if val_loader:
            wrapper.eval()
            val_loss = 0.0
            all_preds, all_targets = [], []

            with torch.no_grad():
                for batch in val_loader:
                    input_data, targets = batch[0].to(device), batch[1].to(device)
                    with autocast():
                        outputs, _ = wrapper(input_data, input_type=input_type)
                        loss = criterion(outputs, targets)

                    val_loss += loss.item()

                    if task_type == "classification":
                        preds = torch.argmax(outputs, dim=1).cpu().numpy()
                        all_preds.extend(preds)
                        all_targets.extend(targets.cpu().numpy())

            avg_val_loss = val_loss / len(val_loader)
            val_losses.append(avg_val_loss)

            time_now = f"{time.time():.0f}"
            # Save the best model
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model_path = os.path.join(results_folder, f"{input_type}_epoch{epoch+1}_valLoss{avg_val_loss:.4f}_{time_now}.pth")
                torch.save(wrapper.state_dict(), best_model_path)
                print(f"Model saved at {best_model_path} with validation loss: {best_val_loss:.4f}")

            # Compute F1-score for classification tasks
            f1 = None
            if task_type == "classification":
                f1 = f1_score(all_targets, all_preds, average="macro")
                print(f"Epoch {epoch + 1}, Validation F1-Score: {f1:.4f}")
                f1_scores.append(f1)

        scheduler.step()

        # Log results
        with open(log_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([task, input_type, epoch + 1, avg_train_loss, avg_val_loss, f1 if f1 is not None else "-", scheduler.get_last_lr()[0], f"{time_now}"])

    # Plot training and validation losses
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, epochs + 1), train_losses, label="Training Loss")
    plt.plot(range(1, epochs + 1), val_losses, label="Validation Loss", linestyle="--")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss")
    plt.legend()
    plt.grid(True)
    # plt.savefig(os.path.join(results_folder, "loss_curve.png"))
    plt.show()

    return wrapper, best_model_path, train_losses, val_losses, f1_scores if task_type == "classification" else 0, attn_maps
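A few hedged usage sketches follow; none of them are part of the commit, and every helper or placeholder name they introduce is hypothetical. First, nmse_loss returns a per-sample normalized squared error, sum((y_true - y_pred)^2) / sum(y_true^2), which the training loops then sum over the batch. A quick self-contained check, assuming this repository's train.py (and the utils module it imports) is on the path:

import torch
from train import nmse_loss  # assumes the repo's train.py is importable as `train`

# Hypothetical batch: 4 samples, 16 features each.
y_true = torch.randn(4, 16)
y_pred = y_true + 0.1 * torch.randn(4, 16)   # small perturbation around the target

per_sample_nmse = nmse_loss(y_pred, y_true)  # shape (4,), roughly 0.01 for this noise level
print(per_sample_nmse)

# Note: train_lwm above calls nmse_loss(masked_tokens, logits_lm), i.e. with the
# ground-truth tokens in the first (y_pred) slot, so the normalization term is
# computed from logits_lm rather than from the targets.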
 
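A minimal sketch of how train_lwm might be driven. ToyLWM and toy_loader are hypothetical scaffolding that only mimic the interface the pre-training loop expects: loader dictionaries keyed by sequence length, batches of (input_ids, masked_tokens, masked_pos), and a model whose forward pass takes (input_ids, masked_pos) and returns the predicted masked patches as its first output.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from train import train_lwm  # assumes the repo's train.py is importable

class ToyLWM(nn.Module):
    # Stand-in for the real LWM model, kept only to illustrate the expected interface.
    def __init__(self, patch_size=32, d_model=64):
        super().__init__()
        self.proj_in = nn.Linear(patch_size, d_model)
        self.proj_out = nn.Linear(d_model, patch_size)

    def forward(self, input_ids, masked_pos):
        h = self.proj_out(self.proj_in(input_ids))                  # (B, seq_len, patch_size)
        idx = masked_pos.unsqueeze(-1).expand(-1, -1, h.size(-1))   # select masked positions
        return torch.gather(h, 1, idx), None, None                  # (predictions, _, _)

def toy_loader(seq_len, n=64, patch_size=32, n_masked=2):
    # Synthetic masked-modeling data: inputs plus the ground-truth patches at masked positions.
    input_ids = torch.randn(n, seq_len, patch_size)
    masked_pos = torch.randint(0, seq_len, (n, n_masked))
    masked_tokens = torch.gather(
        input_ids, 1, masked_pos.unsqueeze(-1).expand(-1, -1, patch_size))
    return DataLoader(TensorDataset(input_ids, masked_tokens, masked_pos), batch_size=16)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ToyLWM().to(device)
train_loaders = {16: toy_loader(16), 32: toy_loader(32)}  # keyed by sequence length
val_loaders = {16: toy_loader(16)}

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Any scheduler exposing get_last_lr() works; train_lwm steps it once per batch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1_000)

model = train_lwm(model, train_loaders, val_loaders, optimizer, scheduler,
                  epochs=2, device=device)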
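Similarly, a hedged sketch of a downstream finetune call. ToyEncoder is a hypothetical stand-in for the pre-trained LWM encoder; it only illustrates the contract that FineTuningWrapper relies on, namely that forward(x) returns (embeddings, attn_maps) with a 128-dimensional CLS embedding at position 0. Loading the real pre-trained model is outside this file.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from train import finetune  # assumes the repo's train.py and utils.py are importable

class ToyEncoder(nn.Module):
    # Stand-in encoder: prepends a learnable CLS token and projects patches to d_model = 128.
    def __init__(self, patch_size=32, d_model=128):
        super().__init__()
        self.proj = nn.Linear(patch_size, d_model)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(self, x):
        h = self.proj(x)                              # (B, n_patches, 128)
        cls = self.cls.expand(x.size(0), -1, -1)
        return torch.cat([cls, h], dim=1), None       # (embeddings, attn_maps)

num_classes, n_samples, n_patches, patch_size = 4, 128, 8, 32
x = torch.randn(n_samples, n_patches, patch_size)
y = torch.randint(0, num_classes, (n_samples,))
train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(x, y), batch_size=32)

wrapper, best_path, train_losses, val_losses, f1_scores, attn_maps = finetune(
    ToyEncoder(),
    train_loader,
    val_loader=val_loader,
    task_type="classification",
    input_type="cls_emb",      # or "channel_emb" / "raw"
    num_classes=num_classes,
    fine_tune_layers=None,     # keep the encoder frozen and train only the head
    epochs=2,
    device="cuda" if torch.cuda.is_available() else "cpu",
    task="Beam Prediction",
)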