"""
Testing base class.
"""
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch

import numpy as np
import math

from . import flow_transforms

class TesterBase:
    """Base class for evaluation drivers.

    Subclasses implement :meth:`define_model`, :meth:`forward` and
    :meth:`display`. :meth:`init_testing` builds the data loader and model;
    :meth:`test` then iterates over ``self.train_loader``, stores the current
    batch on the instance and calls ``forward()`` followed by
    ``display(iteration)``.
    """

    # Per-channel mean subtracted from BSD500 inputs by the input transform;
    # also stored on ``args.device`` as ``self.mean_values`` (shape 1x3x1x1).
    MEAN_VALUES = (0.411, 0.432, 0.45)

    def __init__(self, args):
        """Store the options object and prepare shared constants.

        Args:
            args: parsed options. ``args.device`` must be accepted by
                ``Tensor.to``; the remaining attributes are read by
                :meth:`init_dataset` and :meth:`test`.
        """
        # Let cuDNN benchmark and cache the fastest convolution algorithms.
        cudnn.benchmark = True

        self.mean_values = torch.tensor(self.MEAN_VALUES).view(1, 3, 1, 1).to(args.device)

        self.args = args

    def init_dataset(self):
        """Build ``self.train_loader`` for the dataset named by ``args.dataset``.

        ``'BSD500'`` and ``'texture'`` use project-local datasets; any other
        value falls back to a basicsr ``PairedImageDataset`` (HR/LR pairs).
        """
        dataset_name = self.args.dataset
        if dataset_name == 'BSD500':
            self._init_bsd500_loader()
        elif dataset_name == 'texture':
            self._init_texture_loader()
        else:
            self._init_basicsr_loader()

    def _init_bsd500_loader(self):
        """Load BSD500 through the project factory and wrap its train split."""
        from ..data import BSD500

        def make_input_transform():
            # Divide by 255 (presumably 0-255 pixel input), then subtract the
            # per-channel mean. Train and val get separate, identical instances.
            return transforms.Compose([
                flow_transforms.ArrayToTensor(),
                transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
                transforms.Normalize(mean=list(self.MEAN_VALUES), std=[1, 1, 1]),
            ])

        dataset_kwargs = dict(
            transform=make_input_transform(),
            val_transform=make_input_transform(),
            target_transform=transforms.Compose([
                flow_transforms.ArrayToTensor(),
            ]),
        )
        # A joint centre crop is only applied when crop_img is non-zero.
        if self.args.crop_img != 0:
            dataset_kwargs['co_transform'] = flow_transforms.Compose([
                flow_transforms.CenterCrop((self.args.train_img_height, self.args.train_img_width)),
            ])

        print("=> loading img pairs from '{}'".format(self.args.data))
        train_set, val_set = BSD500(self.args.data, **dataset_kwargs)
        print('{} samples found, {} train samples and {} val samples '.format(
            len(val_set) + len(train_set), len(train_set), len(val_set)))

        self.train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=self.args.batch_size,
            num_workers=self.args.workers, pin_memory=True, shuffle=False, drop_last=True)

    def _init_texture_loader(self):
        """Wrap the project texture dataset (test mode) in a DataLoader."""
        from ..data.texture_v3 import Dataset
        dataset = Dataset(self.args.data_path, crop_size=self.args.train_img_height, test=True)
        self.train_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.args.batch_size,
            num_workers=self.args.workers,
            shuffle=False,
            drop_last=True)

    def _init_basicsr_loader(self):
        """Build a basicsr PairedImageDataset loader from HR_dir / LR_dir."""
        from basicsr.data import create_dataloader, create_dataset
        opt = {
            'dist': False,
            'phase': 'train',
            'name': 'DIV2K',
            'type': 'PairedImageDataset',
            'dataroot_gt': self.args.HR_dir,
            'dataroot_lq': self.args.LR_dir,
            'filename_tmpl': '{}',
            'io_backend': dict(type='disk'),
            'gt_size': self.args.train_img_height,
            'use_flip': True,
            'use_rot': True,
            'use_shuffle': True,
            'num_worker_per_gpu': self.args.workers,
            'batch_size_per_gpu': self.args.batch_size,
            'scale': int(self.args.ratio),
            'dataset_enlarge_ratio': 1,
        }
        dataset = create_dataset(opt)
        self.train_loader = create_dataloader(
            dataset, opt, num_gpu=1, dist=opt['dist'], sampler=None)

    def init_testing(self):
        """Run the setup hooks in order: constants, data loader, model."""
        self.init_constant()
        self.init_dataset()
        self.define_model()

    def init_constant(self):
        """Hook for subclasses to set up extra constants; no-op by default."""
        return

    def define_model(self):
        """Build the model under test. Must be overridden."""
        raise NotImplementedError

    def display(self, iteration=None):
        """Visualise/report results for one batch. Must be overridden.

        :meth:`test` calls this as ``self.display(iteration)``.
        """
        raise NotImplementedError

    def forward(self, iteration=None):
        """Run the model on ``self.image``. Must be overridden.

        :meth:`test` calls this with no argument, so ``iteration`` is optional.
        """
        raise NotImplementedError

    def test(self):
        """Iterate over ``self.train_loader`` and invoke the subclass hooks.

        For each batch, inputs are moved to ``args.device`` and stored on the
        instance:

        * ``'BSD500'``: ``self.image``, ``self.label``; ``self.gt`` is None.
        * ``'texture'``: ``self.image``, ``self.image2``.
        * otherwise (basicsr dict batch): ``self.image`` from ``'lq'``,
          ``self.gt`` from ``'gt'``.

        Then ``self.forward()`` and ``self.display(iteration)`` are called.
        """
        args = self.args
        # Use the same device as self.mean_values instead of the default CUDA device.
        device = args.device
        for iteration, data in enumerate(self.train_loader):
            print("Iteration: {}.".format(iteration))
            if args.dataset == 'BSD500':
                image = data[0].to(device)
                self.label = data[1].to(device)
                self.gt = None
            elif args.dataset == 'texture':
                image = data[0].to(device)
                self.image2 = data[1].to(device)
            else:
                image = data['lq'].to(device)
                self.gt = data['gt'].to(device)
            self.image = image
            self.forward()
            self.display(iteration)
            # NOTE(review): the loop stops only after iteration niteration + 1,
            # i.e. niteration + 2 batches are processed; confirm this is intended.
            if iteration > args.niteration:
                break