# Training script: face segmentation (UNet / ViT / DINO) with Accelerate + wandb.
from typing import Tuple
from pathlib import Path
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed
from matplotlib import cm
import numpy as np
import matplotlib.pyplot as plt
import argparse
import json
import wandb
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from models.unet import UNet
from dataset import SegmentationDataset, collate_fn
from utils import get_transform, mask_transform, EMA
from get_loss import get_composite_criterion
from models.vit import ViTSegmentation
from models.dino import DINOSegmentationModel
# Fixed 9-class RGB palette (0-255) sampled from matplotlib's "tab20" colormap.
# NOTE: matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9;
# plt.get_cmap is the supported equivalent and takes the same (name, lut) args.
color_map = plt.get_cmap('tab20', 9)
fixed_colors = np.array([color_map(i)[:3] for i in range(9)]) * 255
def mask_to_color(mask: np.ndarray) -> np.ndarray:
    """Render a (H, W) integer class mask as an RGB uint8 image.

    Pixels of class ``i`` (0..8) get the i-th color of the module-level
    ``fixed_colors`` palette; any other value stays black.
    """
    colored = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for idx, rgb in enumerate(fixed_colors[:9]):
        colored[mask == idx] = rgb
    return colored
def create_combined_image(
    x: torch.Tensor,
    y: torch.Tensor,
    y_pred: torch.Tensor,
    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
    std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
) -> np.ndarray:
    """Build a 3-row visualization grid: inputs / ground truth / predictions.

    Args:
        x: (B, 3, H, W) normalized input images.
        y: (B, H, W) ground-truth class masks.
        y_pred: (B, H, W) predicted class masks.
        mean: per-channel normalization mean used to de-normalize ``x``.
        std: per-channel normalization std used to de-normalize ``x``.

    Returns:
        (3*H, B*W, 3) uint8 image — one column per sample.

    Note: defaults were mutable lists (a Python anti-pattern); now immutable
    tuples with identical values.
    """
    batch_size, _, height, width = x.shape
    combined_image = np.zeros((height * 3, width * batch_size, 3), dtype=np.uint8)
    # Hoist the list/tuple -> ndarray conversion out of the loop.
    mean_arr = np.asarray(mean)
    std_arr = np.asarray(std)
    for i in range(batch_size):
        image = x[i].cpu().permute(1, 2, 0).numpy()
        # Undo normalization before converting to displayable uint8.
        image = (image * std_arr + mean_arr).clip(0, 1)
        image = (image * 255).astype(np.uint8)
        true_mask_color = mask_to_color(y[i].cpu().numpy())
        pred_mask_color = mask_to_color(y_pred[i].cpu().numpy())
        col = slice(i * width, (i + 1) * width)
        combined_image[:height, col, :] = image
        combined_image[height:2 * height, col, :] = true_mask_color
        combined_image[2 * height:, col, :] = pred_mask_color
    return combined_image
def compute_metrics(y_pred: torch.Tensor, y: torch.Tensor, num_classes: int = 9) -> Tuple[float, float, float, float, float, float]:
    """Per-batch segmentation metrics.

    Args:
        y_pred: (B, H, W) predicted class indices.
        y: (B, H, W) ground-truth class indices.
        num_classes: total number of segmentation classes.

    Returns:
        Tuple of floats: (macro_precision, macro_recall, macro_accuracy,
        micro_precision, micro_recall, global_accuracy). Macro precision and
        recall average only over classes present in ``y``; macro accuracy
        averages over all classes, absent ones included.
    """
    eps = 1e-8
    # Boolean one-hot views of shape (B, H, W, C).
    pred_onehot = y_pred.unsqueeze(-1) == torch.arange(num_classes, device=y_pred.device).view(1, 1, 1, -1)
    true_onehot = y.unsqueeze(-1) == torch.arange(num_classes, device=y.device).view(1, 1, 1, -1)
    spatial = (0, 1, 2)
    present = (true_onehot.sum(dim=spatial) > 0).float()
    # Confusion-matrix counts per class; fp/fn derived from the marginals.
    tp = (pred_onehot & true_onehot).sum(dim=spatial).float()
    fp = pred_onehot.sum(dim=spatial).float() - tp
    fn = true_onehot.sum(dim=spatial).float() - tp
    tn = (~pred_onehot & ~true_onehot).sum(dim=spatial).float()
    per_class_precision = tp / (tp + fp).clamp(min=eps)
    per_class_recall = tp / (tp + fn).clamp(min=eps)
    per_class_accuracy = (tp + tn) / (tp + tn + fp + fn)
    n_present = present.sum().clamp(min=eps)
    macro_precision = ((per_class_precision * present).sum() / n_present).item()
    macro_recall = ((per_class_recall * present).sum() / n_present).item()
    macro_accuracy = per_class_accuracy.mean().item()
    micro_precision = (tp.sum() / (tp.sum() + fp.sum()).clamp(min=eps)).item()
    micro_recall = (tp.sum() / (tp.sum() + fn.sum()).clamp(min=eps)).item()
    # Fraction of pixels classified correctly across the whole batch.
    global_accuracy = (y_pred == y).float().mean().item()
    return macro_precision, macro_recall, macro_accuracy, micro_precision, micro_recall, global_accuracy
def parse_args(argv=None):
    """Parse command-line arguments for the training run.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``.
            Passing an explicit list makes the function unit-testable.

    Returns:
        argparse.Namespace with the full training configuration.
    """
    parser = argparse.ArgumentParser(description="Train a model on human parsing dataset")
    parser.add_argument("--data-path", type=str, default="data/portraits", help="Path to the data")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training and testing")
    # BUGFIX: `type=bool` is an argparse pitfall — bool("False") is True, so the
    # flag could never be disabled. BooleanOptionalAction (Python 3.9+) provides
    # --pin-memory / --no-pin-memory with a True default.
    parser.add_argument("--pin-memory", action=argparse.BooleanOptionalAction, default=True, help="Pin memory for DataLoader")
    parser.add_argument("--num-workers", type=int, default=0, help="Number of workers for DataLoader")
    parser.add_argument("--num-epochs", type=int, default=15, help="Number of training epochs")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
    parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer")
    parser.add_argument("--max-norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--logs-dir", type=str, default="unet-logs", help="Directory for saving logs")
    parser.add_argument("--model", type=str, default="unet", choices=["unet", "vit", "dino"], help="Model class name")
    parser.add_argument("--losses-path", type=str, default="losses_config.json", help="Path to the losses")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["fp16", "bf16", "fp8", "no"], help="Value of the mixed precision")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=2, help="Value of the gradient accumulation steps")
    parser.add_argument("--project-name", type=str, default="face_segmentation_unet", help="WandB project name")
    parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
    parser.add_argument("--log-steps", type=int, default=400, help="Number of steps for logging images")
    parser.add_argument("--seed", type=int, default=42, help="Value of the seed")
    return parser.parse_args(argv)
def main() -> None:
    """Train a face-segmentation model and log metrics/checkpoints.

    Builds the model selected by ``--model``, trains it with mixed precision
    and gradient accumulation via HuggingFace Accelerate, logs per-step and
    per-epoch metrics plus sample image grids to Weights & Biases, and saves
    best / periodic checkpoints under ``--logs-dir``.
    """
    args = parse_args()
    set_seed(args.seed)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision)
    with open(args.losses_path, "r") as fp:
        losses_config = json.load(fp)
    # Create the log directory / wandb run on the main process first so the
    # other processes don't race on the filesystem.
    with accelerator.main_process_first():
        logs_dir = Path(args.logs_dir)
        logs_dir.mkdir(exist_ok=True)
        wandb.init(project=args.project_name, dir=logs_dir)
        wandb.save(args.losses_path)
    optimizer_class = getattr(torch.optim, args.optimizer)
    if args.model == "unet":
        model = UNet().to(accelerator.device)
        optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    elif args.model == "vit":
        model = ViTSegmentation().to(accelerator.device)
        optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    elif args.model == "dino":
        # DINO: only the segmentation head is optimized (frozen backbone).
        model = DINOSegmentationModel().to(accelerator.device)
        optimizer = optimizer_class(model.segmentation_head.parameters(), lr=args.learning_rate)
    else:
        raise NotImplementedError("Incorrect model name")
    transform = get_transform(model.mean, model.std)
    train_dataset = SegmentationDataset(args.data_path, subset="train", transform=transform, target_transform=mask_transform)
    valid_dataset = SegmentationDataset(args.data_path, subset="test", transform=transform, target_transform=mask_transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn)
    criterion = get_composite_criterion(losses_config)
    # One cosine cycle over the whole run; NOTE(review): the scheduler is
    # stepped once per micro-batch below, so with accumulation > 1 the cycle
    # completes as scheduled only if T_max counts micro-batches — confirm.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs * len(train_loader))
    # NOTE(review): valid_loader is not passed to prepare(); each process runs
    # validation over the full dataset — confirm this is intended.
    model, optimizer, train_loader, lr_scheduler = accelerator.prepare(model, optimizer, train_loader, lr_scheduler)
    best_accuracy = 0
    print(f"params: {sum([p.numel() for p in model.parameters()])/1e6:.2f} M")
    print(f"trainable params: {sum([p.numel() for p in model.parameters() if p.requires_grad])/1e6:.2f} M")
    # EMAs smooth the noisy per-step metrics shown on the progress bar.
    train_loss_ema, train_macro_precision_ema, train_macro_recall_ema, train_macro_accuracy_ema, train_micro_precision_ema, train_micro_recall_ema, train_global_accuracy_ema = EMA(), EMA(), EMA(), EMA(), EMA(), EMA(), EMA()
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{args.num_epochs}")
        for index, (x, y) in enumerate(pbar):
            # Masks arrive as (B, 1, H, W); squeeze to (B, H, W) class indices.
            x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
            with accelerator.accumulate(model):
                with accelerator.autocast():
                    output = model(x)
                    loss = criterion(output, y)
                accelerator.backward(loss)
                train_loss = loss.item()
                grad_norm = None
                _, y_pred = output.max(dim=1)
                train_macro_precision, train_macro_recall, train_macro_accuracy, train_micro_precision, train_micro_recall, train_global_accuracy = compute_metrics(y_pred, y)
                # Clip only when gradients are actually synced (real optimizer step).
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_norm).item()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            # Periodically log an input / ground-truth / prediction image grid.
            if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                images_to_log = []
                combined_image = create_combined_image(x, y, y_pred, model.mean, model.std)
                images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Train, Epoch {epoch}, Batch {index})"))
                wandb.log({"train_samples": images_to_log})
            pbar.set_postfix({"loss": train_loss_ema(train_loss), "macro_precision": train_macro_precision_ema(train_macro_precision), "macro_recall": train_macro_recall_ema(train_macro_recall), "macro_accuracy": train_macro_accuracy_ema(train_macro_accuracy), "micro_precision": train_micro_precision_ema(train_micro_precision), "micro_recall": train_micro_recall_ema(train_micro_recall), "global_accuracy": train_global_accuracy_ema(train_global_accuracy)})
            log_data = {
                "train/epoch": epoch,
                "train/loss": train_loss,
                "train/macro_accuracy": train_macro_accuracy,
                "train/learning_rate": optimizer.param_groups[0]["lr"],
                "train/macro_precision": train_macro_precision,
                "train/macro_recall": train_macro_recall,
                "train/micro_precision": train_micro_precision,
                "train/micro_recall": train_micro_recall,
                "train/global_accuracy": train_global_accuracy,
            }
            if grad_norm is not None:
                log_data["train/grad_norm"] = grad_norm
            if accelerator.is_main_process:
                wandb.log(log_data)
        accelerator.wait_for_everyone()
        model.eval()
        valid_loss, valid_macro_accuracies, valid_macro_precisions, valid_macro_recalls, valid_global_accuracies, valid_micro_precisions, valid_micro_recalls = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        with torch.inference_mode():
            pbar = tqdm(valid_loader, desc=f"Val epoch {epoch}/{args.num_epochs}")
            # BUGFIX: iterate the tqdm wrapper (was `valid_loader`), otherwise
            # the validation progress bar is created but never advances.
            for index, (x, y) in enumerate(pbar):
                x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
                output = model(x)
                _, y_pred = output.max(dim=1)
                if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                    images_to_log = []
                    combined_image = create_combined_image(x, y, y_pred, model.mean, model.std)
                    images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Validation, Epoch {epoch})"))
                    wandb.log({"valid_samples": images_to_log})
                valid_macro_precision, valid_macro_recall, valid_macro_accuracy, valid_micro_precision, valid_micro_recall, valid_global_accuracy = compute_metrics(y_pred, y)
                valid_macro_precisions += valid_macro_precision
                valid_macro_recalls += valid_macro_recall
                valid_macro_accuracies += valid_macro_accuracy
                valid_micro_precisions += valid_micro_precision
                valid_micro_recalls += valid_micro_recall
                valid_global_accuracies += valid_global_accuracy
                loss = criterion(output, y)
                valid_loss += loss.item()
        # Convert per-batch sums into means over the validation set.
        valid_loss = valid_loss / len(valid_loader)
        valid_macro_accuracies = valid_macro_accuracies / len(valid_loader)
        valid_macro_precisions = valid_macro_precisions / len(valid_loader)
        valid_macro_recalls = valid_macro_recalls / len(valid_loader)
        valid_global_accuracies = valid_global_accuracies / len(valid_loader)
        valid_micro_precisions = valid_micro_precisions / len(valid_loader)
        valid_micro_recalls = valid_micro_recalls / len(valid_loader)
        accelerator.print(f"loss: {valid_loss:.3f}, valid_macro_precision: {valid_macro_precisions:.3f}, valid_macro_recall: {valid_macro_recalls:.3f}, valid_macro_accuracy: {valid_macro_accuracies:.3f}, valid_micro_precision: {valid_micro_precisions:.3f}, valid_micro_recall: {valid_micro_recalls:.3f}, valid_global_accuracy: {valid_global_accuracies:.3f}")
        if accelerator.is_main_process:
            wandb.log(
                {
                    "val/epoch": epoch,
                    "val/loss": valid_loss,
                    "val/macro_accuracy": valid_macro_accuracies,
                    "val/macro_precision": valid_macro_precisions,
                    "val/macro_recall": valid_macro_recalls,
                    "val/global_accuracy": valid_global_accuracies,
                    "val/micro_precision": valid_micro_precisions,
                    "val/micro_recall": valid_micro_recalls,
                }
            )
        # Save a "best" checkpoint whenever validation global accuracy improves.
        if valid_global_accuracies > best_accuracy:
            best_accuracy = valid_global_accuracies
            if args.model in ["dino", "vit"]:
                accelerator.save(model.segmentation_head.state_dict(), logs_dir / "checkpoint-best.pth")
            else:
                accelerator.save(model.state_dict(), logs_dir / "checkpoint-best.pth")
            accelerator.print(f"new best_accuracy {best_accuracy}, {epoch=}")
        # Periodic epoch-numbered checkpoints.
        if epoch % args.save_frequency == 0:
            if args.model in ["dino", "vit"]:
                accelerator.save(model.segmentation_head.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
            else:
                accelerator.save(model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
        accelerator.wait_for_everyone()
    accelerator.wait_for_everyone()
    wandb.finish()
# Script entry point. (Removed a stray trailing "|" extraction artifact that
# made the original line a SyntaxError.)
if __name__ == "__main__":
    main()