import copy
import glob
import os
import time
from collections import OrderedDict

import numpy as np
import torch
import torch.distributed as dist
from torch.cuda.amp import autocast

# The packages below are not needed by the training logic itself (yaml and
# matplotlib are not referenced by Trainer; tqdm only draws a progress bar),
# so a missing installation must not make this module un-importable.
try:
    import yaml
except ImportError:
    yaml = None

try:
    from matplotlib import pyplot as plt
except ImportError:
    plt = None

try:
    from tqdm import tqdm
except ImportError:
    class tqdm(object):
        """Silent stand-in for ``tqdm.tqdm`` covering the calls made below."""

        def __init__(self, iterable):
            self._iterable = iterable

        def __iter__(self):
            return iter(self._iterable)

        def set_description(self, desc):
            pass


class Trainer(object):
    """
    A class that encapsulates the training loop for a PyTorch model.

    Batches are expected to be dicts of the form
    ``{'audio': {'array': ..., 'fft': ...}, 'label': ...}``; the two audio
    tensors are stacked along a new dim 1 before being fed to the model, and
    predictions are turned into classes by thresholding ``sigmoid(logits)``
    at 0.5.

    Several constructor arguments (``scaler``, ``grad_clip``, ``cos_inc``,
    ``range_update``, ``accumulation_step``, ``wandb_log``, ``num_quantiles``,
    ``update_func``, ``world_size``, ``plot_every``) are stored on the
    instance but are not used by any method of this class.
    """

    # Wall-clock budget of fit(), in hours; training stops after the epoch
    # in which this budget is exceeded.
    TIME_LIMIT_HOURS = 23.83

    def __init__(self, model, optimizer, criterion, train_dataloader, device, world_size=1, output_dim=2,
                 scheduler=None, val_dataloader=None, max_iter=np.inf, scaler=None,
                 grad_clip=False, exp_num=None, log_path=None, exp_name=None, plot_every=None,
                 cos_inc=False, range_update=None, accumulation_step=1, wandb_log=False,
                 num_quantiles=1, update_func=lambda x: x):
        """
        Store the training components and locate the dataloaders' samplers.

        ``max_iter`` caps the number of batches processed per epoch (and per
        ``predict`` call). ``log_path``/``exp_num``/``exp_name`` determine
        where ``fit`` writes checkpoints; when ``log_path`` is None no
        checkpoint is written to disk.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.scaler = scaler
        self.grad_clip = grad_clip
        self.cos_inc = cos_inc
        self.output_dim = output_dim
        self.scheduler = scheduler
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.train_sampler = self.get_sampler_from_dataloader(train_dataloader)
        self.val_sampler = self.get_sampler_from_dataloader(val_dataloader)
        self.max_iter = max_iter
        self.device = device
        self.world_size = world_size
        self.exp_num = exp_num
        self.exp_name = exp_name
        self.log_path = log_path
        self.best_state_dict = None
        self.plot_every = plot_every
        self.logger = None
        self.range_update = range_update
        self.accumulation_step = accumulation_step
        self.wandb = wandb_log
        self.num_quantiles = num_quantiles
        self.update_func = update_func

    def get_sampler_from_dataloader(self, dataloader):
        """
        Return the sampler behind ``dataloader``, or None if none is found.

        Lookup order: a ``DistributedSampler`` directly on ``.sampler``, a
        sampler nested one level down (``.sampler.sampler``), then
        ``.batch_sampler.sampler``. Objects without these attributes
        (including None) yield None.
        """
        if hasattr(dataloader, 'sampler'):
            if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):
                return dataloader.sampler
            elif hasattr(dataloader.sampler, 'sampler'):
                return dataloader.sampler.sampler
        if hasattr(dataloader, 'batch_sampler') and hasattr(dataloader.batch_sampler, 'sampler'):
            return dataloader.batch_sampler.sampler
        return None

    def fit(self, num_epochs, device, early_stopping=None, only_p=False, best='loss', conf=False):
        """
        Fits the model for the given number of epochs.

        After every epoch the validation set is evaluated and, when the
        monitored quantity improves (lower validation loss for
        ``best='loss'``, higher mean validation accuracy otherwise), the
        weights are snapshotted into ``self.best_state_dict`` and, if
        ``log_path`` is set, saved to
        ``{log_path}/{exp_num}/{exp_name}.pth``.

        Training stops early after ``early_stopping`` consecutive epochs
        without improvement (never, when it is None) or once
        ``TIME_LIMIT_HOURS`` of wall-clock time have elapsed.

        ``only_p`` and ``conf`` are accepted for interface compatibility and
        are not used.

        Returns a dict with ``num_epochs`` (the requested number), the
        per-batch ``train_loss``/``val_loss`` lists and the per-epoch
        ``train_acc``/``val_acc``/``lrs`` lists.
        """
        min_loss = np.inf
        best_acc = 0
        train_loss, val_loss = [], []
        train_acc, val_acc = [], []
        lrs = []
        epochs_without_improvement = 0
        # Every process currently behaves as the main one.
        main_proccess = True  # change in a ddp setting

        print(f"Starting training for {num_epochs} epochs")
        print("is main process: ", main_proccess, flush=True)
        global_time = time.time()
        self.epoch = 0
        for epoch in range(num_epochs):
            self.epoch = epoch
            start_time = time.time()

            t_loss, t_acc = self.train_epoch(device, epoch=epoch)
            t_loss_mean = np.nanmean(t_loss)
            train_loss.extend(t_loss)
            global_train_accuracy, global_train_loss = self.process_loss(t_acc, t_loss_mean)
            if main_proccess:  # Only perform this on the master GPU
                train_acc.append(global_train_accuracy.mean().item())

            v_loss, v_acc = self.eval_epoch(device, epoch=epoch)
            v_loss_mean = np.nanmean(v_loss)
            val_loss.extend(v_loss)
            global_val_accuracy, global_val_loss = self.process_loss(v_acc, v_loss_mean)
            if main_proccess:  # Only perform this on the master GPU
                val_acc.append(global_val_accuracy.mean().item())

                if best == 'loss':
                    current_objective = global_val_loss
                    improved = bool(current_objective < min_loss)
                    if improved:
                        min_loss = current_objective
                else:
                    current_objective = global_val_accuracy.mean()
                    improved = bool(current_objective > best_acc)
                    if improved:
                        best_acc = current_objective

                if improved:
                    self._checkpoint_best()
                    epochs_without_improvement = 0
                else:
                    epochs_without_improvement += 1

                current_lr = self.optimizer.param_groups[0]['lr'] if self.scheduler is None \
                    else self.scheduler.get_last_lr()[0]
                lrs.append(current_lr)

                print(f'Epoch {epoch}, lr {current_lr}, Train Loss: {global_train_loss:.6f}, Val Loss:'
                      f'{global_val_loss:.6f}, Train Acc: {global_train_accuracy.round(decimals=4).tolist()}, '
                      f'Val Acc: {global_val_accuracy.round(decimals=4).tolist()},'
                      f'Time: {time.time() - start_time:.2f}s, Total Time: {(time.time() - global_time)/3600} hr',
                      flush=True)
                # The GPU report only makes sense when a CUDA device exists.
                if epoch % 10 == 0 and torch.cuda.is_available():
                    print(os.system('nvidia-smi'))

                if epochs_without_improvement == early_stopping:
                    print('early stopping!', flush=True)
                    break
                if time.time() - global_time > (self.TIME_LIMIT_HOURS * 3600):
                    print("time limit reached")
                    break

        return {"num_epochs": num_epochs, "train_loss": train_loss,
                "val_loss": val_loss, "train_acc": train_acc, "val_acc": val_acc, "lrs": lrs}

    def _checkpoint_best(self):
        """Snapshot the current weights in memory and, if configured, on disk."""
        state_dict = self.model.state_dict()
        if self.log_path is not None:
            model_name = f'{self.log_path}/{self.exp_num}/{self.exp_name}.pth'
            os.makedirs(os.path.dirname(model_name), exist_ok=True)
            print(f"saving model at {model_name}...")
            torch.save(state_dict, model_name)
        # state_dict() hands out references to the live parameters, so a copy
        # is required for the snapshot to survive further optimizer steps.
        self.best_state_dict = copy.deepcopy(state_dict)

    def process_loss(self, acc, loss_mean):
        """
        Convert an epoch's accuracy and mean loss to tensors, averaging
        across processes when running distributed.

        When CUDA is available and ``torch.distributed`` is initialized, both
        values are moved to the GPU, sum-reduced onto rank 0 and divided by
        the world size. ``reduce`` delivers the final result to the
        destination rank only, so the averaged values are meaningful on
        rank 0 alone.

        Returns ``(global_accuracy, global_loss)``.
        """
        # Copy so that the in-place reduce/divide below never touch ``acc``.
        if isinstance(acc, torch.Tensor):
            global_accuracy = acc.detach().clone()
        else:
            global_accuracy = torch.tensor(acc)
        global_loss = torch.tensor(loss_mean)

        if torch.cuda.is_available() and torch.distributed.is_initialized():
            global_accuracy = global_accuracy.cuda()
            torch.distributed.reduce(global_accuracy, dst=0, op=torch.distributed.ReduceOp.SUM)
            global_loss = global_loss.cuda()
            torch.distributed.reduce(global_loss, dst=0, op=torch.distributed.ReduceOp.SUM)

            # Divide both loss and accuracy by world size
            world_size = torch.distributed.get_world_size()
            global_loss /= world_size
            global_accuracy /= world_size
        return global_accuracy, global_loss

    def load_best_model(self, to_ddp=True, from_ddp=True):
        """
        Load a ``.pth`` checkpoint from the experiment directory into the model.

        ``{log_path}/exp{exp_num}`` is searched first; if it holds no
        checkpoint, ``{log_path}/{exp_num}`` (the directory ``fit`` saves to)
        is searched as well. Of the files found, the last one in sorted order
        is loaded.

        With ``to_ddp=False`` tensors are mapped onto ``self.device`` while
        loading. With ``from_ddp=True`` every leading ``module.`` prefix is
        stripped from the parameter names. Keys are matched non-strictly.

        Raises ``IndexError`` when no checkpoint file is found.
        """
        search_dirs = (f'{self.log_path}/exp{self.exp_num}', f'{self.log_path}/{self.exp_num}')
        state_dict_files = []
        for data_dir in search_dirs:
            # glob returns files in arbitrary order; sort for a stable choice.
            state_dict_files = sorted(glob.glob(data_dir + '/*.pth'))
            if state_dict_files:
                break
        if not state_dict_files:
            raise IndexError(f"no .pth checkpoint found in any of {search_dirs}")

        checkpoint = state_dict_files[-1]
        print("loading model from ", checkpoint)
        if to_ddp:
            state_dict = torch.load(checkpoint)
        else:
            state_dict = torch.load(checkpoint, map_location=self.device)

        if from_ddp:
            print("loading distributed model")
            # Remove "module." from keys
            prefix = 'module.'
            new_state_dict = OrderedDict()
            for key, value in state_dict.items():
                while key.startswith(prefix):
                    key = key[len(prefix):]
                new_state_dict[key] = value
            state_dict = new_state_dict
        self.model.load_state_dict(state_dict, strict=False)

    def check_gradients(self, threshold=10):
        """Print every parameter whose gradient norm exceeds ``threshold``."""
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                if grad_norm > threshold:
                    print(f"Large gradient in {name}: {grad_norm}")

    @staticmethod
    def _prepare_batch(batch, device):
        """Move a batch to ``device`` and stack audio and fft along a new dim 1."""
        x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
        x = x.to(device).float()
        fft = fft.to(device).float()
        y = y.to(device).float()
        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
        return x_fft, y

    def _forward(self, x_fft, y):
        """
        Run the model and return predictions shaped like the labels.

        ``squeeze()`` without a dim also removes the batch axis of a
        single-sample batch; whenever the element counts agree the
        predictions are therefore reshaped to the label shape.
        """
        y_pred = self.model(x_fft).squeeze()
        if y_pred.shape != y.shape and y_pred.numel() == y.numel():
            y_pred = y_pred.reshape(y.shape)
        return y_pred

    @staticmethod
    def _classify(y_pred, y):
        """Threshold sigmoid(logits) at 0.5; return (classes, number correct)."""
        probs = torch.sigmoid(y_pred.detach())
        cls_pred = (probs > 0.5).float()
        acc = (cls_pred == y).sum()
        return cls_pred, acc

    def train_epoch(self, device, epoch):
        """
        Trains the model for one epoch.

        At most ``self.max_iter`` batches are processed. Returns the list of
        per-batch losses and a tensor of length ``output_dim`` holding the
        number of correct predictions divided by the number of samples.
        """
        if self.train_sampler is not None:
            # Only samplers that define set_epoch (e.g. DistributedSampler)
            # need to be told the epoch.
            set_epoch = getattr(self.train_sampler, 'set_epoch', None)
            if callable(set_epoch):
                set_epoch(epoch)
        self.model.train()
        train_loss = []
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.train_dl)
        for i, batch in enumerate(pbar):
            if i >= self.max_iter:
                break
            if self.optimizer is not None:
                self.optimizer.zero_grad()
            loss, acc, y = self.train_batch(batch, i, device)
            train_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"train_acc: {acc}, train_loss: {loss.item()}")
        return train_loss, all_accs / total

    def train_batch(self, batch, batch_idx, device):
        """
        Run one optimization step on ``batch``.

        The scheduler, when present, is stepped once per batch. Returns
        ``(loss, number of correct predictions, labels)``.
        """
        x_fft, y = self._prepare_batch(batch, device)
        y_pred = self._forward(x_fft, y)
        loss = self.criterion(y_pred, y)
        loss.backward()
        if self.optimizer is not None:
            self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        # get predicted classes
        _, acc = self._classify(y_pred, y)
        return loss, acc, y

    def eval_epoch(self, device, epoch):
        """
        Evaluates the model for one epoch.

        At most ``self.max_iter`` batches of ``self.val_dl`` are processed.
        Returns the list of per-batch losses and a tensor of length
        ``output_dim`` holding the number of correct predictions divided by
        the number of samples.
        """
        self.model.eval()
        val_loss = []
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.val_dl)
        for i, batch in enumerate(pbar):
            if i >= self.max_iter:
                break
            loss, acc, y = self.eval_batch(batch, i, device)
            val_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"val_acc: {acc}, val_loss: {loss.item()}")
        return val_loss, all_accs / total

    def eval_batch(self, batch, batch_idx, device):
        """
        Evaluate ``batch`` without tracking gradients.

        Returns ``(loss, number of correct predictions, labels)``.
        """
        x_fft, y = self._prepare_batch(batch, device)
        with torch.no_grad():
            y_pred = self._forward(x_fft, y)
            loss = self.criterion(y_pred, y)
        _, acc = self._classify(y_pred, y)
        return loss, acc, y

    def predict(self, test_dataloader, device):
        """
        Returns the predictions of the model on the given dataset.

        At most ``self.max_iter`` batches are processed. Returns
        ``(predictions, true_labels, accuracy)`` where the first two are flat
        lists with one entry per sample and ``accuracy`` is the fraction of
        correct predictions (NaN when no sample was processed).
        """
        self.model.eval()
        total = 0
        all_accs = 0
        predictions = []
        true_labels = []
        pbar = tqdm(test_dataloader)
        for i, batch in enumerate(pbar):
            if i >= self.max_iter:
                break
            x_fft, y = self._prepare_batch(batch, device)
            with torch.no_grad():
                y_pred = self._forward(x_fft, y)
            cls_pred, acc = self._classify(y_pred, y)
            predictions.extend(cls_pred.cpu().numpy())
            true_labels.extend(y.cpu().numpy())
            all_accs += acc
            total += len(y)
            pbar.set_description("acc: {:.4f}".format(acc))
        accuracy = all_accs / total if total else float('nan')
        return predictions, true_labels, accuracy