# Hugging Face Spaces badge residue: "Spaces: Running on Zero"
from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import trange

from Dataloader import *
from .utils import image_quality
from .utils.cls import CyclicLR
from .utils.prepare_images import *

# Folder paths kept for reference; the actual samples are read from the
# SQLite image databases below.
train_folder = "./dataset/train"
test_folder = "./dataset/test"

# Training set: pre-generated (LR, HR) pairs stored in an image database.
img_dataset = ImageDBData(
    db_file="dataset/images.db",
    db_table="train_images_size_128_noise_1_rgb",
    max_images=24,
)
img_data = DataLoader(img_dataset, batch_size=6, shuffle=True, num_workers=6)
total_batch = len(img_data)  # batches per epoch; also sizes the LR cycle below
print(len(img_dataset))

# Held-out test set, evaluated one image at a time.
test_dataset = ImageDBData(
    db_file="dataset/test2.db",
    db_table="test_images_size_128_noise_1_rgb",
    max_images=None,
)
num_test = len(test_dataset)
test_data = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)
# Pixel-wise L1 loss between the upscaled output and the HR target.
criteria = nn.L1Loss()

# CARN_V2 2x upscaler: 3-channel RGB in/out, SE blocks enabled,
# no dilation (atrous rates all 1).
model = CARN_V2(
    color_channels=3,
    mid_channels=64,
    conv=nn.Conv2d,
    single_conv_size=3,
    single_conv_group=1,
    scale=2,
    activation=nn.LeakyReLU(0.1),
    SEBlock=True,
    repeat_blocks=3,
    atrous=(1, 1, 1),
)
model.total_parameters()
# model.initialize_weights_xavier_uniform()

# fp16 training is available in GPU only
model = network_to_half(model).cuda()

# Resume weights from the most recent checkpoint.
model.load_state_dict(torch.load("CARN_model_checkpoint.pt"))
learning_rate = 1e-4
weight_decay = 1e-6

# AMSGrad variant of Adam with light L2 regularization.
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    amsgrad=True,
)
# Alternatives kept from earlier experiments:
# optimizer = optim.SGD(model.parameters(), momentum=0.9, nesterov=True, weight_decay=weight_decay, lr=learning_rate)
# optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0, verbose=False)
# optimizer.load_state_dict(torch.load("CARN_adam_checkpoint.pt"))

# -1 starts a fresh LR cycle; restore from "CARN_scheduler_last_iter" to resume.
last_iter = -1
scheduler = CyclicLR(
    optimizer,
    base_lr=1e-4,
    max_lr=1e-4,  # base == max, so the cycle is currently flat
    step_size=3 * total_batch,  # half-cycle spans three epochs
    mode="triangular",
    last_batch_iteration=last_iter,
)
# Per-batch training history and per-epoch test history.
train_loss = []
train_ssim = []
train_psnr = []

test_loss = []
test_ssim = []
test_psnr = []

# To resume a previous run, reload the saved histories instead:
# train_loss = torch.load("train_loss.pt")
# train_ssim = torch.load("train_ssim.pt")
# train_psnr = torch.load("train_psnr.pt")
# test_loss = torch.load("test_loss.pt")
# test_ssim = torch.load("test_ssim.pt")
# test_psnr = torch.load("test_psnr.pt")

counter = 0    # global step count; drives periodic checkpointing
iteration = 2  # number of epochs to run

# Epoch-level progress bar; postfix fields are refreshed every batch.
ibar = trange(
    iteration,
    ascii=True,
    maxinterval=1,
    postfix={"avg_loss": 0, "train_ssim": 0, "test_ssim": 0},
)
# NOTE(review): indentation was lost in this file's extraction; it is
# reconstructed here so that the epoch-end checkpointing and the test
# evaluation run once per epoch inside the loop, as the "End of One Epoch"
# comment indicates — confirm against the upstream training script.
for i in ibar:
    # -----------------------------------------------------------------
    # Train for one epoch.
    # -----------------------------------------------------------------
    for index, batch in enumerate(img_data):
        scheduler.batch_step()  # cyclic LR advances once per batch
        lr_img, hr_img = batch
        lr_img = lr_img.cuda().half()  # fp16 input for the half-precision model
        hr_img = hr_img.cuda()

        optimizer.zero_grad()
        outputs = model(lr_img)
        outputs = outputs.float()  # compute the loss in fp32 for stability
        loss = criteria(outputs, hr_img)
        # BUG FIX: the original called optimizer.backward(loss), which only
        # exists on apex's FP16_Optimizer wrapper — and that wrapper is
        # commented out above, leaving a plain optim.Adam that has no
        # .backward() (AttributeError at runtime). Backprop belongs on the
        # loss tensor itself.
        loss.backward()
        # nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()

        counter += 1
        ssim = image_quality.msssim(outputs, hr_img).item()
        psnr = image_quality.psnr(outputs, hr_img).item()
        ibar.set_postfix(
            ratio=index / total_batch,
            loss=loss.item(),
            ssim=ssim,
            batch=index,
            psnr=psnr,
            lr=scheduler.current_lr,
        )
        train_loss.append(loss.item())
        train_ssim.append(ssim)
        train_psnr.append(psnr)

        # +++++++++++++++++++++++++++++++++++++
        # save checkpoints by iterations
        # -------------------------------------
        if (counter + 1) % 500 == 0:
            torch.save(model.state_dict(), "CARN_model_checkpoint.pt")
            torch.save(optimizer.state_dict(), "CARN_adam_checkpoint.pt")
            torch.save(train_loss, "train_loss.pt")
            torch.save(train_ssim, "train_ssim.pt")
            torch.save(train_psnr, "train_psnr.pt")
            torch.save(scheduler.last_batch_iteration, "CARN_scheduler_last_iter.pt")

    # +++++++++++++++++++++++++++++++++++++
    # End of One Epoch: persist model/optimizer state and history.
    # -------------------------------------
    torch.save(model.state_dict(), "CARN_model_checkpoint.pt")
    torch.save(optimizer.state_dict(), "CARN_adam_checkpoint.pt")
    torch.save(train_loss, "train_loss.pt")
    torch.save(train_ssim, "train_ssim.pt")
    torch.save(train_psnr, "train_psnr.pt")

    # +++++++++++++++++++++++++++++++++++++
    # Test: per-epoch evaluation on the held-out set.
    # -------------------------------------
    with torch.no_grad():
        ssim = []
        batch_loss = []
        psnr = []
        for index, test_batch in enumerate(test_data):
            lr_img, hr_img = test_batch
            # NOTE(review): training casts the input with .half() but this
            # path does not; if network_to_half() does not cast its input
            # internally this will fail on dtype — confirm against
            # utils.prepare_images before relying on it.
            lr_img = lr_img.cuda()
            hr_img = hr_img.cuda()
            lr_img_up = model(lr_img)
            lr_img_up = lr_img_up.float()
            loss = criteria(lr_img_up, hr_img)
            save_image([lr_img_up[0], hr_img[0]], f"check_test_imgs/{index}.png")
            batch_loss.append(loss.item())
            ssim.append(image_quality.msssim(lr_img_up, hr_img).item())
            psnr.append(image_quality.psnr(lr_img_up, hr_img).item())
        test_ssim.append(np.mean(ssim))
        test_loss.append(np.mean(batch_loss))
        test_psnr.append(np.mean(psnr))
        torch.save(test_loss, "test_loss.pt")
        torch.save(test_ssim, "test_ssim.pt")
        torch.save(test_psnr, "test_psnr.pt")

# import subprocess
# subprocess.call(["shutdown", "/s"])