import gradio as gr
from load_model import extract_sel_mean_std_bias_assignemnt
from pathlib import Path
from architectures.model_mapping import get_model
from configs.dataset_params import dataset_constants
import torch
import torchvision.transforms as transforms
import pandas as pd
import cv2
import numpy as np
from PIL import Image
from get_data import get_augmentation
from configs.dataset_params import normalize_params
import random
from evaluation.diversity import MultiKCrossChannelMaxPooledSum

def _to_uint8(arr):
    """Min-max scale ``arr`` to [0, 255] and cast to uint8 (all zeros if ``arr`` is constant)."""
    lo, hi = np.min(arr), np.max(arr)
    if hi == lo:
        # Avoid 0/0 -> NaN, whose uint8 cast is undefined.
        return np.zeros(np.shape(arr), dtype=np.uint8)
    return ((arr - lo) / (hi - lo) * 255).astype(np.uint8)


def overlapping_features_on_input(model, output, feature_maps, input, target, size=448):
    """Overlay every feature map selected for one class onto the input image.

    Args:
        model: model whose ``linear.layer.weight`` row for the class marks the
            selected features (non-zero entries); column ``S`` of that row is
            paired with ``feature_maps[S]``.
        output: model logits; only used when ``target`` is None, in which case
            their argmax is taken as the 0-based class index.
        feature_maps: feature-map tensor; squeezed and indexed by channel
            (presumably (1, C, H, W) -> (C, H, W) — TODO confirm).
        input: image (array-like) expected in RGB order; the overlays are
            produced in RGB as well.
        target: 1-based class id, or None to use the prediction from ``output``.
        size: side length every returned image is resized to.

    Returns:
        List of uint8 RGB images: the resized input, followed by one heatmap
        overlay per selected feature, in feature-index order.
    """
    weights = model.linear.layer.weight
    maps = feature_maps.detach().cpu().numpy().squeeze()
    print("feature_maps", maps.shape)

    if target is not None:
        label = target - 1  # target is 1-based, weight rows are 0-based
    else:
        label = np.argmax(output.detach().cpu().numpy())

    selection = weights[label, :]
    print("W", selection)
    input_np = np.array(input)
    h, w = input_np.shape[:2]
    print("h,w:", h, w)

    input_np = cv2.resize(input_np, (size, size))
    feature_images = [input_np]

    selected = [idx for idx, wt in enumerate(selection.detach().cpu().tolist()) if wt != 0]
    for idx in selected:
        heat = _to_uint8(cv2.resize(maps[idx], (size, size)))
        # applyColorMap returns BGR; convert before blending with the RGB input
        # so that both the heatmap and the underlying image end up in RGB.
        heat = cv2.cvtColor(cv2.applyColorMap(heat, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
        blended = 0.3 * heat + 0.7 * input_np
        feature_images.append(_to_uint8(blended))

    print("len of Features:", len(selected))
    return feature_images


def genreate_intepriable_output(input, dataset="CUB2011", arch="resnet50", seed=123456, model_type="qsenn", n_features=50, n_per_class=5, img_size=448, reduced_strides=False, folder=None, with_featuremaps=True):
    """Load the finetuned model from its checkpoints and classify a single image.

    Args:
        input: image as a numpy array accepted by ``PIL.Image.fromarray``.
        dataset: key into ``dataset_constants`` (provides the number of classes).
        arch: architecture name passed to ``get_model``.
        seed: only used to build the default checkpoint folder.
        model_type, n_features, n_per_class: select the checkpoint
            ``{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth`` and the
            feature selection ``SlDD_Selection_{n_features}.pt``.
        img_size: side length of the center crop fed to the model.
        reduced_strides: passed to ``get_model``.
        folder: checkpoint directory (str or Path); defaults to
            ``tmp/{arch}/{dataset}/{seed}/``.
        with_featuremaps: if True, also return the feature maps.

    Returns:
        ``(predicted_class, model, feature_maps)`` when ``with_featuremaps`` is
        True, else ``(predicted_class, model)``. ``predicted_class`` is the
        1-based argmax of the logits.
    """
    n_classes = dataset_constants[dataset]["num_classes"]
    device = torch.device("cpu")

    image = Image.fromarray(input)
    print("input shape", image.size)

    model = get_model(arch, n_classes, reduced_strides)
    transform = transform_input_img(image, img_size)

    # Accept str as well as Path so the "/" joins below work.
    folder = Path(f"tmp/{arch}/{dataset}/{seed}/") if folder is None else Path(folder)

    # Dense weights are loaded first; the finetuned state dict then overwrites
    # them after set_model_sldd has reconfigured the model.
    model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth", map_location=device))
    state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth", map_location=device)
    # The selection file is keyed by the number of selected features.
    state_dict["linear.selection"] = torch.load(folder / f"SlDD_Selection_{n_features}.pt", map_location=device)

    feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
    model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
    model.load_state_dict(state_dict)

    batch = transform(image).unsqueeze(0).to(device)
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        output, feature_maps, _final_features = model(batch, with_feature_maps=True, with_final_features=True)
        print("featuresmap size:", feature_maps.size())
        predicted = np.argmax(output.detach().cpu().numpy()) + 1  # 1-based class id

    if with_featuremaps:
        return predicted, model, feature_maps
    return predicted, model

def get_options_from_trainingset(output, model, TR, device, with_other_class):
    """Collect explanation images and negative-feature ids for candidate classes.

    Args:
        output: predicted 1-based class id; always the first entry returned.
        model: model whose ``linear.layer.weight`` row ``t - 1`` holds the
            feature weights of class ``t``.
        TR, device: unused; kept for interface compatibility.
        with_other_class: if True, also add three distinct random other classes
            drawn from the class ids in ``image_class_labels.txt``.

    Returns:
        List of ``(image, negative_ids)`` tuples, one per candidate class. Each
        ``image`` is ``options_heatmap/{class}.jpg`` converted to RGB.
        ``negative_ids`` holds the 1-based positions, among the class's non-zero
        weights, of the weights that are negative.

    Raises:
        FileNotFoundError: if an ``options_heatmap/{class}.jpg`` cannot be read.
    """
    print("outputclass:", output)

    if with_other_class:
        labels = pd.read_csv("image_class_labels.txt", sep=' ', names=['img_id', 'target'])
        # Derive the candidate classes from the data instead of a hard-coded range
        # (the old range(1, 200) could never pick class 200).
        class_ids = sorted(int(c) for c in labels['target'].unique())
        other_targets = random.sample([c for c in class_ids if c != output], 3)
        all_targets = [output] + other_targets
    else:
        all_targets = [output]

    weights = model.linear.layer.weight
    model.eval()
    options = []
    with torch.no_grad():
        for t in all_targets:
            class_weights = weights[t - 1, :]
            nonzero_weights = class_weights[class_weights != 0].tolist()
            negative_ids = [pos + 1 for pos, wt in enumerate(nonzero_weights) if wt < 0]

            path = f"options_heatmap/{t}.jpg"
            image = cv2.imread(path)
            if image is None:  # cv2.imread signals failure by returning None
                raise FileNotFoundError(f"could not read explanation image {path}")
            options.append((cv2.cvtColor(image, cv2.COLOR_BGR2RGB), negative_ids))
    return options


def transform_input_img(input, img_size):
    """Return a transform that fits ``input`` into an ``img_size`` square tensor.

    The shorter side is resized to ``img_size`` while keeping the aspect ratio.
    The result is then center-cropped to ``img_size`` x ``img_size`` and
    converted to a tensor.

    Args:
        input: PIL image; only its ``size`` (width, height) is read.
        img_size: target side length.
    """
    width, height = input.size
    aspect = width / height
    if width >= height:
        # Landscape or square: height is the short side.
        target_hw = (img_size, int(img_size * aspect))
    else:
        # Portrait: width is the short side.
        target_hw = (int(img_size / aspect), img_size)

    steps = [
        transforms.Resize(target_hw),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)




def post_next_image(OPT: str, key: str):
    """Compare the user's choice with the model's answer and lock the buttons.

    Args:
        OPT: option letter chosen by the user.
        key: option letter the model actually predicted.

    Returns:
        A tuple of the feedback message followed by four
        ``gr.update(interactive=False)`` values.
    """
    if OPT == key:
        message = "Congratulations! you can simulate the prediction of Model this time"
    else:
        message = f"sorry, what the model predicted is {key}"
    # One update per answer button; all are disabled once an answer is given.
    return (message, *(gr.update(interactive=False) for _ in range(4)))



def get_features_on_interface(input):
    """Classify ``input`` and build a shuffled four-way multiple-choice question.

    Returns:
        ``(option, key_op, prompt, update)``. ``option`` is a list of
        ``(explanation_image, letter)`` pairs labelled "A".."D". ``key_op`` is
        the letter of the model's predicted class. ``prompt`` is the question
        text. ``update`` is ``gr.update(interactive=False)``.
    """
    img_size = 448
    output, model = genreate_intepriable_output(input, dataset="CUB2011",
                                arch="resnet50", seed=123456,
                                model_type="qsenn", n_features=50, n_per_class=5,
                                img_size=img_size, reduced_strides=False, folder=None, with_featuremaps=False)
    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    op = get_options_from_trainingset(output, model, TR, device, with_other_class=True)

    correct_image = op[0][0]  # the predicted class is always returned first
    random.shuffle(op)
    option = [(image, letter) for (image, _negatives), letter in zip(op, "ABCD")]
    # Match by object identity: comparing pixel values would be ambiguous if two
    # explanation images were identical, and it is needlessly expensive.
    key_op = next(letter for image, letter in option if image is correct_image)
    print("key", key_op)

    return option, key_op, " These are some class explanations from our model for different classes,which of these classes has our model predicted?", gr.update(interactive=False)

def direct_inference(input):
    """Classify ``input`` and build the explanation image plus a caption.

    The returned image is built side by side. On the left is the transformed
    input stacked vertically with one heatmap overlay per selected feature of
    the predicted class. On the right is that class's precomputed explanation
    image.

    Returns:
        ``(image, caption)``. ``caption`` names the predicted class and lists
        the (1-based) features with negative weight, if any.
    """
    img_size = 448
    output, model, feature_maps = genreate_intepriable_output(input, dataset="CUB2011",
                                arch="resnet50", seed=123456,
                                model_type="qsenn", n_features=50, n_per_class=5,
                                img_size=img_size, reduced_strides=False, folder=None, with_featuremaps=True)
    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    class_options = get_options_from_trainingset(output, model, TR, device, with_other_class=False)
    class_image, negative_features = class_options[0]

    # Re-apply the model's input transform so the overlays match what the model saw.
    pil_input = Image.fromarray(input)
    input_tensor = transform_input_img(pil_input, img_size)(pil_input)
    image_np = (input_tensor * 255).clamp(0, 255).byte().permute(1, 2, 0).numpy()

    overlays = overlapping_features_on_input(model, output, feature_maps, image_np, output)
    concatenated_ORI = np.concatenate([np.array(img) for img in overlays], axis=0)
    print(concatenated_ORI.shape, class_image.shape)
    concatenated_image_final_array = np.concatenate((concatenated_ORI, class_image), axis=1)
    print(concatenated_image_final_array.shape)

    data_dir = Path("tmp/Datasets/CUB200/CUB_200_2011/")
    classlist = pd.read_csv(data_dir / "classes.txt", sep=' ', names=['cl_id', 'class_name'])
    output_name = classlist.loc[classlist['cl_id'] == output, 'class_name'].values[0]
    if negative_features:
        output_name_and_features = f"{output_name}, features {', '.join(map(str, negative_features))} are negative."
    else:
        output_name_and_features = f"{output_name}, all features are positive."

    return concatenated_image_final_array, output_name_and_features

def filter_with_diversity(featuremaps, output, weight):
    """Return the locality score at index 4 of a k=1..5 cross-channel max-pooled sum.

    Both tensors are moved to the CPU before evaluation. The value is returned
    as a Python scalar. Presumably this is the "diversity@5" measure (index 4
    of ``range(1, 6)`` is k=5) — verify against MultiKCrossChannelMaxPooledSum.
    """
    ks = range(1, 6)
    localizer = MultiKCrossChannelMaxPooledSum(ks, weight, None)
    on_cpu_output = output.to("cpu")
    on_cpu_maps = featuremaps.to("cpu")
    localizer(on_cpu_output, on_cpu_maps)

    locality, _exclusive_locality = localizer.get_result()
    return locality[4].item()