import gradio as gr
import numpy as np
import torch
import transformers

from carbon_theme import Carbon

from art.estimators.classification.hugging_face import HuggingFaceClassifierPyTorch
from art.attacks.evasion import ProjectedGradientDescentPyTorch, AdversarialPatchPyTorch
from art.utils import load_dataset

from art.attacks.poisoning import PoisoningAttackBackdoor
from art.attacks.poisoning.perturbations import insert_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def clf_poison_evaluate(*args):
    """Train a backdoor-poisoned classifier and return predictions and accuracy
    on the clean and poisoned subsets."""
    # Gradio passes the UI inputs positionally; unpack them in order.
    attack, model_type, target_class, data_type = args

    print('attack', attack)
    print('model_type', model_type)
    print('data_type', data_type)
    print('target_class', target_class)
    
    if model_type == "Example":
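        # Load a DeiT-tiny backbone with a fresh 10-class head and wrap it as an ART classifier.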
        model = transformers.AutoModelForImageClassification.from_pretrained(
            'facebook/deit-tiny-distilled-patch16-224',
            ignore_mismatched_sizes=True,
            num_labels=10
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        loss_fn = torch.nn.CrossEntropyLoss()

        poison_hf_model = HuggingFaceClassifierPyTorch(
            model=model,
            loss=loss_fn,
            optimizer=optimizer,
            input_shape=(3, 224, 224),
            nb_classes=10,
            clip_values=(0, 1),
        )
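        # Start from clean (un-poisoned) weights for the Imagenette classifier.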
        poison_hf_model.model.load_state_dict(torch.load('./state_dicts/deit_imagenette_clean_model.pt', map_location=device))
        
    if data_type == "Example":
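        # Load the Imagenette training split, resize every image to 224x224,
        # and keep only the first 100 images per class.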
        import torchvision
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor(),
        ])
        train_dataset = torchvision.datasets.ImageFolder(root="./data/imagenette2-320/train", transform=transform)
        labels = np.asarray(train_dataset.targets)
        classes = np.unique(labels)
        samples_per_class = 100

        x_subset = []
        y_subset = []

        for c in classes:
            indices = np.where(labels == c)[0][:samples_per_class]
            for i in indices:
                img, label = train_dataset[i]  # load each sample once (transform applied here)
                x_subset.append(img)
                y_subset.append(label)

        x_subset = np.stack(x_subset)
        y_subset = np.asarray(y_subset)
        label_names = [
            'fish',
            'dog',
            'cassette player',
            'chainsaw',
            'church',
            'french horn',
            'garbage truck',
            'gas pump',
            'golf ball',
            'parachute',
        ]
        
    if attack == "Backdoor":
        # Trigger function: blend a 32x32 trigger image (./tmp.png) into the
        # top-left corner of a channels-first input.
        def poison_func(x):
            return insert_image(
                x,
                backdoor_path='./tmp.png',
                channels_first=True,
                random=False,
                x_shift=0,
                y_shift=0,
                size=(32, 32),
                mode='RGB',
                blend=0.8
            )
            
        backdoor = PoisoningAttackBackdoor(poison_func)
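        # Poison half of the 'fish' (class 0) training images: stamp the trigger
        # onto them and relabel them to the chosen target class.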
        source_class = 0
        target_class = label_names.index(target_class)
        poison_percent = 0.5

        x_poison = np.copy(x_subset)
        y_poison = np.copy(y_subset)
        is_poison = np.zeros(len(x_subset)).astype(bool)

        indices = np.where(y_subset == source_class)[0]
        num_poison = int(poison_percent * len(indices))

        for i in indices[:num_poison]:
            x_poison[i], _ = backdoor.poison(x_poison[i], [])
            y_poison[i] = target_class
            is_poison[i] = True

        # Fine-tune the wrapped model on the partially poisoned training set.
        print('fitting')
        print('x_poison', len(x_poison))
        print('y_poison', len(y_poison))
        poison_hf_model.fit(x_poison, y_poison, nb_epochs=2)
        print('finished fitting')
        
        # Accuracy on the samples that were left clean.
        clean_x = x_poison[~is_poison]
        clean_y = y_poison[~is_poison]

        outputs = poison_hf_model.predict(clean_x)
        clean_preds = np.argmax(outputs, axis=1)
        clean_acc = np.mean(clean_preds == clean_y)

        # Pair each image (converted to HWC for display) with its predicted label.
        clean_out = []
        for i, im in enumerate(clean_x):
            clean_out.append((im.transpose(1, 2, 0), label_names[clean_preds[i]]))
        
        # Accuracy on the poisoned samples; since their labels were flipped to the
        # target class, this is the attack success rate.
        poison_x = x_poison[is_poison]
        poison_y = y_poison[is_poison]

        outputs = poison_hf_model.predict(poison_x)
        poison_preds = np.argmax(outputs, axis=1)
        poison_acc = np.mean(poison_preds == poison_y)

        poison_out = []
        for i, im in enumerate(poison_x):
            poison_out.append((im.transpose(1, 2, 0), label_names[poison_preds[i]]))

        return clean_out, poison_out, clean_acc, poison_acc

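# Smoke-test the backdoor demo for several target classes and print the labels
# predicted for the poisoned images (ideally all equal to the chosen target).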
_, poison_out, _, _ = clf_poison_evaluate('Backdoor', 'Example', 'dog', 'Example')
print([i[1] for i in poison_out])
_, poison_out, _, _ = clf_poison_evaluate('Backdoor', 'Example', 'church', 'Example')
print([i[1] for i in poison_out])
_, poison_out, _, _ = clf_poison_evaluate('Backdoor', 'Example', 'gas pump', 'Example')
print([i[1] for i in poison_out])
_, poison_out, _, _ = clf_poison_evaluate('Backdoor', 'Example', 'golf ball', 'Example')
print([i[1] for i in poison_out])