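"""Finetuning / evaluation script for a COVER-style video quality assessment model.

Trains a multi-branch (semantic / technical / aesthetic) scorer with per-branch
PLCC + rank supervision, optionally tracks an EMA copy of the weights, and
validates with SRCC / PLCC / KRCC / RMSE.

Example invocation (flag names and defaults are taken from the argparse setup
in main(); substitute the actual script filename and your own config paths):

    python <this_script>.py -o cover.yml -t val-kv1k -n COVER_TMP
"""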
import torch
import cv2
import random
import os.path as osp
import argparse
from scipy.stats import spearmanr, pearsonr
from scipy.stats import kendalltau as kendallr
import numpy as np
from time import time
from tqdm import tqdm
import pickle
import math
import wandb
import yaml
from collections import OrderedDict
from functools import reduce
from thop import profile
import copy

import cover.models as models
import cover.datasets as datasets

def train_test_split(dataset_path, ann_file, ratio=0.8, seed=42):
    random.seed(seed)
    print(seed)
    video_infos = []
    with open(ann_file, "r") as fin:
        for line in fin.readlines():
            line_split = line.strip().split(",")
            filename, _, _, label = line_split
            label = float(label)
            filename = osp.join(dataset_path, filename)
            video_infos.append(dict(filename=filename, label=label))
    random.shuffle(video_infos)
    return (
        video_infos[: int(ratio * len(video_infos))],
        video_infos[int(ratio * len(video_infos)) :],
    )

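# Pairwise ranking hinge loss: penalizes every (i, j) pair whose predicted
# order disagrees with the ground-truth order, normalized by the number of
# ordered pairs and rescaled by (1 + max hinge) to bound the magnitude.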
def rank_loss(y_pred, y):
    ranking_loss = torch.nn.functional.relu(
        (y_pred - y_pred.t()) * torch.sign(y.t() - y)
    )
    scale = 1 + torch.max(ranking_loss)
    return (
        torch.sum(ranking_loss) / y_pred.shape[0] / (y_pred.shape[0] - 1) / scale
    ).float()

def gaussian(y, eps=1e-8):
    return (y - y.mean()) / (y.std() + eps)

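# PLCC-induced loss: z-normalize predictions and labels, then average the MSE
# between them and the MSE after scaling predictions by the batch correlation
# rho; each MSE is divided by 4 so both terms stay bounded in [0, 1].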
def plcc_loss(y_pred, y):
    sigma_hat, m_hat = torch.std_mean(y_pred, unbiased=False)
    y_pred = (y_pred - m_hat) / (sigma_hat + 1e-8)
    sigma, m = torch.std_mean(y, unbiased=False)
    y = (y - m) / (sigma + 1e-8)
    loss0 = torch.nn.functional.mse_loss(y_pred, y) / 4
    rho = torch.mean(y_pred * y)
    loss1 = torch.nn.functional.mse_loss(rho * y_pred, y) / 4
    return ((loss0 + loss1) / 2).float()

def rescaled_l2_loss(y_pred, y, eps=1e-8):
    y_pred_rs = (y_pred - y_pred.mean()) / (y_pred.std() + eps)
    y_rs = (y - y.mean()) / (y.std() + eps)
    return torch.nn.functional.mse_loss(y_pred_rs, y_rs)

def rplcc_loss(y_pred, y, eps=1e-8):
    ## Literally (1 - PLCC) / 2
    y_pred, y = gaussian(y_pred), gaussian(y)
    cov = torch.sum(y_pred * y) / y_pred.shape[0]
    # std = (torch.std(y_pred) + eps) * (torch.std(y) + eps)
    return (1 - cov) / 2

def self_similarity_loss(f, f_hat, f_hat_detach=False):
    if f_hat_detach:
        f_hat = f_hat.detach()
    return 1 - torch.nn.functional.cosine_similarity(f, f_hat, dim=1).mean()


def contrastive_similarity_loss(f, f_hat, f_hat_detach=False, eps=1e-8):
    if f_hat_detach:
        f_hat = f_hat.detach()
    intra_similarity = torch.nn.functional.cosine_similarity(f, f_hat, dim=1).mean()
    cross_similarity = torch.nn.functional.cosine_similarity(f, f_hat, dim=0).mean()
    return (1 - intra_similarity) / (1 - cross_similarity + eps)

def rescale(pr, gt=None):
    if gt is None:
        pr = (pr - np.mean(pr)) / np.std(pr)
    else:
        pr = ((pr - np.mean(pr)) / np.std(pr)) * np.std(gt) + np.mean(gt)
    return pr


sample_types = ["semantic", "technical", "aesthetic"]

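# One training epoch. The model returns one score map per branch listed in
# sample_types; with need_separate_sup each branch is supervised separately
# with plcc_loss + 0.3 * rank_loss against the same ground-truth label.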
def finetune_epoch(
    ft_loader,
    model,
    model_ema,
    optimizer,
    scheduler,
    device,
    epoch=-1,
    need_upsampled=False,
    need_feat=False,
    need_fused=False,
    need_separate_sup=True,
):
    model.train()
    for i, data in enumerate(tqdm(ft_loader, desc=f"Training in epoch {epoch}")):
        optimizer.zero_grad()
        video = {}
        for key in sample_types:
            if key in data:
                video[key] = data[key].to(device)
        y = data["gt_label"].float().detach().to(device).unsqueeze(-1)
        scores = model(video, inference=False, reduce_scores=False)
        if len(scores) > 1:
            y_pred = reduce(lambda x, y: x + y, scores)
        else:
            y_pred = scores[0]
        y_pred = y_pred.mean((-3, -2, -1))
        frame_inds = data["frame_inds"]
        loss = 0  # p_loss + 0.3 * r_loss
        if need_separate_sup:
            p_loss_a = plcc_loss(scores[0].mean((-3, -2, -1)), y)
            p_loss_b = plcc_loss(scores[1].mean((-3, -2, -1)), y)
            p_loss_c = plcc_loss(scores[2].mean((-3, -2, -1)), y)
            r_loss_a = rank_loss(scores[0].mean((-3, -2, -1)), y)
            r_loss_b = rank_loss(scores[1].mean((-3, -2, -1)), y)
            r_loss_c = rank_loss(scores[2].mean((-3, -2, -1)), y)
            loss += (
                p_loss_a + p_loss_b + p_loss_c
                + 0.3 * r_loss_a + 0.3 * r_loss_b + 0.3 * r_loss_c
            )  # + 0.2 * o_loss
            wandb.log(
                {
                    "train/plcc_loss_a": p_loss_a.item(),
                    "train/plcc_loss_b": p_loss_b.item(),
                    "train/plcc_loss_c": p_loss_c.item(),
                }
            )
        wandb.log({"train/total_loss": loss.item()})
        loss.backward()
        optimizer.step()
        scheduler.step()
        # ft_loader.dataset.refresh_hypers()
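        # Exponential moving average of the weights, updated every step:
        # ema <- 0.999 * ema + (1 - 0.999) * model.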
        if model_ema is not None:
            model_params = dict(model.named_parameters())
            model_ema_params = dict(model_ema.named_parameters())
            for k in model_params.keys():
                model_ema_params[k].data.mul_(0.999).add_(
                    model_params[k].data, alpha=1 - 0.999
                )
    model.eval()

def profile_inference(inf_set, model, device):
    video = {}
    data = inf_set[0]
    for key in sample_types:
        if key in data:
            video[key] = data[key].to(device).unsqueeze(0)
    with torch.no_grad():
        flops, params = profile(model, (video,))
    print(
        f"The FLOPs of the variant are {flops/1e9:.1f}G, with {params/1e6:.2f}M params."
    )

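# Validation pass. Each view is reshaped from (b, c, t, h, w) into
# (b * num_clips, c, t // num_clips, h, w) so clips are scored independently,
# clip scores are averaged per video, and SRCC / PLCC / KRCC / RMSE are
# computed against the ground-truth labels. The checkpoint is saved whenever
# SRCC + PLCC improves on the running best.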
def inference_set(
    inf_loader,
    model,
    device,
    best_,
    save_model=False,
    suffix="s",
    save_name="divide",
    save_type="head",
):
    results = []
    best_s, best_p, best_k, best_r = best_
    for i, data in enumerate(tqdm(inf_loader, desc="Validating")):
        result = dict()
        video, video_up = {}, {}
        for key in sample_types:
            if key in data:
                video[key] = data[key].to(device)
                ## Reshape into clips
                b, c, t, h, w = video[key].shape
                video[key] = (
                    video[key]
                    .reshape(
                        b, c, data["num_clips"][key], t // data["num_clips"][key], h, w
                    )
                    .permute(0, 2, 1, 3, 4, 5)
                    .reshape(
                        b * data["num_clips"][key], c, t // data["num_clips"][key], h, w
                    )
                )
            if key + "_up" in data:
                video_up[key] = data[key + "_up"].to(device)
                ## Reshape into clips (per-key clip count, matching the block above)
                b, c, t, h, w = video_up[key].shape
                video_up[key] = (
                    video_up[key]
                    .reshape(
                        b, c, data["num_clips"][key], t // data["num_clips"][key], h, w
                    )
                    .permute(0, 2, 1, 3, 4, 5)
                    .reshape(
                        b * data["num_clips"][key], c, t // data["num_clips"][key], h, w
                    )
                )
                # .unsqueeze(0)
        with torch.no_grad():
            result["pr_labels"] = model(video, reduce_scores=True).cpu().numpy()
            if len(list(video_up.keys())) > 0:
                result["pr_labels_up"] = model(video_up).cpu().numpy()
        result["gt_label"] = data["gt_label"].item()
        del video, video_up
        results.append(result)

    ## Collect per-video scores and compute correlation metrics
    gt_labels = [r["gt_label"] for r in results]
    pr_labels = [np.mean(r["pr_labels"][:]) for r in results]
    pr_labels = rescale(pr_labels, gt_labels)

    s = spearmanr(gt_labels, pr_labels)[0]
    p = pearsonr(gt_labels, pr_labels)[0]
    k = kendallr(gt_labels, pr_labels)[0]
    r = np.sqrt(((gt_labels - pr_labels) ** 2).mean())

    wandb.log(
        {
            f"val_{suffix}/SRCC-{suffix}": s,
            f"val_{suffix}/PLCC-{suffix}": p,
            f"val_{suffix}/KRCC-{suffix}": k,
            f"val_{suffix}/RMSE-{suffix}": r,
        }
    )

    del results, result  # , video, video_up
    torch.cuda.empty_cache()
    if s + p > best_s + best_p and save_model:
        state_dict = model.state_dict()
        if save_type == "head":
            head_state_dict = OrderedDict()
            for key, v in state_dict.items():
                if "backbone" in key:
                    continue
                else:
                    head_state_dict[key] = v
            print("The following keys are saved:", head_state_dict.keys())
            torch.save(
                {"state_dict": head_state_dict, "validation_results": best_},
                f"pretrained_weights/{save_name}_{suffix}_finetuned.pth",
            )
        else:
            torch.save(
                {"state_dict": state_dict, "validation_results": best_},
                f"pretrained_weights/{save_name}_{suffix}_finetuned.pth",
            )
    best_s, best_p, best_k, best_r = (
        max(best_s, s),
        max(best_p, p),
        max(best_k, k),
        min(best_r, r),
    )

    wandb.log(
        {
            f"val_{suffix}/best_SRCC-{suffix}": best_s,
            f"val_{suffix}/best_PLCC-{suffix}": best_p,
            f"val_{suffix}/best_KRCC-{suffix}": best_k,
            f"val_{suffix}/best_RMSE-{suffix}": best_r,
        }
    )

    print(
        f"For {len(inf_loader)} videos, \nthe accuracy of the model: [{suffix}] is as follows:\n"
        f"  SROCC: {s:.4f} best: {best_s:.4f}\n"
        f"  PLCC:  {p:.4f} best: {best_p:.4f}\n"
        f"  KROCC: {k:.4f} best: {best_k:.4f}\n"
        f"  RMSE:  {r:.4f} best: {best_r:.4f}."
    )

    return best_s, best_p, best_k, best_r
    # torch.save(results, f'{args.save_dir}/results_{dataset.lower()}_s{32}*{32}_ens{args.famount}.pkl')

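# Overall flow: load the YAML options, optionally re-split the target set into
# train/eval folds (one run per split), then finetune in two stages -- linear
# probing with a frozen backbone (l_num_epochs) followed by end-to-end
# training (num_epochs) -- validating after every epoch.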
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o", "--opt", type=str, default="cover.yml", help="the option file"
    )
    parser.add_argument(
        "-t", "--target_set", type=str, default="val-kv1k", help="target_set"
    )
    parser.add_argument(
        "-n", "--name", type=str, default="COVER_TMP",
        help="model name used to save checkpoints",
    )
    parser.add_argument(
        "-uh", "--usehead", type=int, default=0,
        help="whether to load head weights from a checkpoint",
    )
    args = parser.parse_args()
    with open(args.opt, "r") as f:
        opt = yaml.safe_load(f)
    print(opt)

    ## adaptively choose the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ## defining model and loading checkpoint
    bests_ = []
    if opt.get("split_seed", -1) > 0:
        num_splits = 10
    else:
        num_splits = 1
    print(opt.get("split_seed", -1))
    for split in range(num_splits):
        model = getattr(models, opt["model"]["type"])(**opt["model"]["args"]).to(device)
        if opt.get("split_seed", -1) > 0:
            opt["data"]["train"] = copy.deepcopy(opt["data"][args.target_set])
            opt["data"]["eval"] = copy.deepcopy(opt["data"][args.target_set])
            split_duo = train_test_split(
                opt["data"][args.target_set]["args"]["data_prefix"],
                opt["data"][args.target_set]["args"]["anno_file"],
                seed=opt["split_seed"] * (split + 1),
            )
            (
                opt["data"]["train"]["args"]["anno_file"],
                opt["data"]["eval"]["args"]["anno_file"],
            ) = split_duo
            opt["data"]["train"]["args"]["sample_types"]["technical"]["num_clips"] = 1

        train_datasets = {}
        for key in opt["data"]:
            if key.startswith("train"):
                train_dataset = getattr(datasets, opt["data"][key]["type"])(
                    opt["data"][key]["args"]
                )
                train_datasets[key] = train_dataset
                print(len(train_dataset.video_infos))

        train_loaders = {}
        for key, train_dataset in train_datasets.items():
            train_loaders[key] = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=opt["batch_size"],
                num_workers=opt["num_workers"],
                shuffle=True,
            )

        val_datasets = {}
        for key in opt["data"]:
            if key.startswith("eval"):
                val_dataset = getattr(datasets, opt["data"][key]["type"])(
                    opt["data"][key]["args"]
                )
                print(len(val_dataset.video_infos))
                val_datasets[key] = val_dataset

        val_loaders = {}
        for key, val_dataset in val_datasets.items():
            val_loaders[key] = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=1,
                num_workers=opt["num_workers"],
                pin_memory=True,
            )
        run = wandb.init(
            project=opt["wandb"]["project_name"],
            name=opt["name"] + f"_target_{args.target_set}_split_{split}"
            if num_splits > 1
            else opt["name"],
            reinit=True,
            settings=wandb.Settings(start_method="thread"),
        )
        state_dict = torch.load(opt["test_load_path"], map_location=device)

        # Optionally load fine-tuned head weights from a separate checkpoint
        if args.usehead:
            state_dict_head = torch.load(opt["test_load_header_path"], map_location=device)
            state_dict.update(state_dict_head["state_dict"])

        # strict=False allows missing (e.g. head) weights
        model.load_state_dict(state_dict, strict=False)
| if opt["ema"]: | |
| from copy import deepcopy | |
| model_ema = deepcopy(model) | |
| else: | |
| model_ema = None | |
| # profile_inference(val_dataset, model, device) | |
| # finetune the model | |
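        # Backbone parameters train at a reduced learning rate
        # (lr * backbone_lr_mult); every other submodule uses the base lr.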
        param_groups = []
        for key, value in dict(model.named_children()).items():
            if "backbone" in key:
                param_groups += [
                    {
                        "params": value.parameters(),
                        "lr": opt["optimizer"]["lr"]
                        * opt["optimizer"]["backbone_lr_mult"],
                    }
                ]
            else:
                param_groups += [
                    {"params": value.parameters(), "lr": opt["optimizer"]["lr"]}
                ]

        optimizer = torch.optim.AdamW(
            lr=opt["optimizer"]["lr"],
            params=param_groups,
            weight_decay=opt["optimizer"]["wd"],
        )
        warmup_iter = 0
        for train_loader in train_loaders.values():
            warmup_iter += int(opt["warmup_epochs"] * len(train_loader))
        max_iter = int((opt["num_epochs"] + opt["l_num_epochs"]) * len(train_loader))
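        # LR schedule: linear warmup for warmup_iter steps, then cosine
        # annealing over the next max_iter steps.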
        lr_lambda = (
            lambda cur_iter: cur_iter / warmup_iter
            if cur_iter <= warmup_iter
            else 0.5 * (1 + math.cos(math.pi * (cur_iter - warmup_iter) / max_iter))
        )
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        bests = {}
        bests_n = {}
        for key in val_loaders:
            bests[key] = -1, -1, -1, 1000
            bests_n[key] = -1, -1, -1, 1000
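        # Stage 1: linear probing -- freeze the backbone and train the heads only.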
        for key, value in dict(model.named_children()).items():
            if "backbone" in key:
                for param in value.parameters():
                    param.requires_grad = False
| for epoch in range(opt["l_num_epochs"]): | |
| print(f"Linear Epoch {epoch}:") | |
| for key, train_loader in train_loaders.items(): | |
| finetune_epoch( | |
| train_loader, | |
| model, | |
| model_ema, | |
| optimizer, | |
| scheduler, | |
| device, | |
| epoch, | |
| opt.get("need_upsampled", False), | |
| opt.get("need_feat", False), | |
| opt.get("need_fused", False), | |
| ) | |
| for key in val_loaders: | |
| bests[key] = inference_set( | |
| val_loaders[key], | |
| model_ema if model_ema is not None else model, | |
| device, | |
| bests[key], | |
| save_model=opt["save_model"], | |
| save_name=args.name + "_head_" + args.target_set + f"_{split}", | |
| suffix=key + "_s", | |
| ) | |
| if model_ema is not None: | |
| bests_n[key] = inference_set( | |
| val_loaders[key], | |
| model, | |
| device, | |
| bests_n[key], | |
| save_model=opt["save_model"], | |
| save_name=args.name | |
| + "_head_" | |
| + args.target_set | |
| + f"_{split}", | |
| suffix=key + "_n", | |
| ) | |
| else: | |
| bests_n[key] = bests[key] | |
| if opt["l_num_epochs"] >= 0: | |
| for key in val_loaders: | |
| print( | |
| f"""For the linear transfer process on {key} with {len(val_loaders[key])} videos, | |
| the best validation accuracy of the model-s is as follows: | |
| SROCC: {bests[key][0]:.4f} | |
| PLCC: {bests[key][1]:.4f} | |
| KROCC: {bests[key][2]:.4f} | |
| RMSE: {bests[key][3]:.4f}.""" | |
| ) | |
| print( | |
| f"""For the linear transfer process on {key} with {len(val_loaders[key])} videos, | |
| the best validation accuracy of the model-n is as follows: | |
| SROCC: {bests_n[key][0]:.4f} | |
| PLCC: {bests_n[key][1]:.4f} | |
| KROCC: {bests_n[key][2]:.4f} | |
| RMSE: {bests_n[key][3]:.4f}.""" | |
| ) | |
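        # Stage 2: end-to-end finetuning -- unfreeze the backbone and train all parameters.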
        for key, value in dict(model.named_children()).items():
            if "backbone" in key:
                for param in value.parameters():
                    param.requires_grad = True
| for epoch in range(opt["num_epochs"]): | |
| print(f"End-to-end Epoch {epoch}:") | |
| for key, train_loader in train_loaders.items(): | |
| finetune_epoch( | |
| train_loader, | |
| model, | |
| model_ema, | |
| optimizer, | |
| scheduler, | |
| device, | |
| epoch, | |
| opt.get("need_upsampled", False), | |
| opt.get("need_feat", False), | |
| opt.get("need_fused", False), | |
| ) | |
| for key in val_loaders: | |
| bests[key] = inference_set( | |
| val_loaders[key], | |
| model_ema if model_ema is not None else model, | |
| device, | |
| bests[key], | |
| save_model=opt["save_model"], | |
| save_name=args.name + "_head_" + args.target_set + f"_{split}", | |
| suffix=key + "_s", | |
| save_type="full", | |
| ) | |
| if model_ema is not None: | |
| bests_n[key] = inference_set( | |
| val_loaders[key], | |
| model, | |
| device, | |
| bests_n[key], | |
| save_model=opt["save_model"], | |
| save_name=args.name | |
| + "_head_" | |
| + args.target_set | |
| + f"_{split}", | |
| suffix=key + "_n", | |
| save_type="full", | |
| ) | |
| else: | |
| bests_n[key] = bests[key] | |
| if opt["num_epochs"] >= 0: | |
| for key in val_loaders: | |
| print( | |
| f"""For the end-to-end transfer process on {key} with {len(val_loaders[key])} videos, | |
| the best validation accuracy of the model-s is as follows: | |
| SROCC: {bests[key][0]:.4f} | |
| PLCC: {bests[key][1]:.4f} | |
| KROCC: {bests[key][2]:.4f} | |
| RMSE: {bests[key][3]:.4f}.""" | |
| ) | |
| print( | |
| f"""For the end-to-end transfer process on {key} with {len(val_loaders[key])} videos, | |
| the best validation accuracy of the model-n is as follows: | |
| SROCC: {bests_n[key][0]:.4f} | |
| PLCC: {bests_n[key][1]:.4f} | |
| KROCC: {bests_n[key][2]:.4f} | |
| RMSE: {bests_n[key][3]:.4f}.""" | |
| ) | |
        for key, value in dict(model.named_children()).items():
            if "backbone" in key:
                for param in value.parameters():
                    param.requires_grad = True

        run.finish()

if __name__ == "__main__":
    main()