import os
import sys
import argparse

import numpy as np
from tqdm import tqdm

import torch
device = torch.device(r'cuda' if torch.cuda.is_available() else r'cpu')

from util import utils
from stealth_edit import editors


def edit(args):

    # loading hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # save additional params to hparams
    hparams['Delta'] = args.Delta

    # add static context
    if args.static_context is not None:
        hparams['static_context'] = args.static_context

    # load model and tokenizer
    print('\nLoading model:', args.model)
    model, tok = utils.load_model_tok(model_name=args.model)

    # load dataset
    if (args.edit_mode == 'in-place') and (args.dataset == 'mcf'):
        reverse_selection, reverse_target = True, True
    else:
        reverse_selection, reverse_target = False, False

    print('Loading dataset:', args.dataset)
    ds, _, _ = utils.load_dataset(
        tok, 
        ds_name=args.dataset, 
        selection=args.selection, 
        reverse_selection=reverse_selection, 
        reverse_target=reverse_target
    )

    # find other feature vectors (from wikipedia dataset) 
    if args.other_pickle is not None:
        other_features = utils.loadpickle(args.other_pickle)['features']
        other_features = torch.from_numpy(other_features).to(device)
    else:
        other_features = None

    existing_files = [f for f in os.listdir(args.save_path) if f.endswith('.pickle')]
    sampled_case_ids = [int(f.split('.pickle')[0]) for f in existing_files]
    num_sampled = len(sampled_case_ids)

    if args.to_run is not None:
        args.sample_size = args.to_run + num_sampled

    print('Found {:} existing files in {:}'.format(len(existing_files), args.save_path))

    pbar = tqdm(total=args.sample_size)
    pbar.update(num_sampled)

    while num_sampled < args.sample_size:

        # sample a random request
        request_idx = np.random.randint(0, len(ds))

        # find subject request 
        request = ds.data[request_idx]['requested_rewrite']

        # find case id
        case_id = ds.data[request_idx]["case_id"]
        request['case_id'] = case_id

        if case_id in sampled_case_ids:
            continue

        # construct save path and check if already exists
        output_path = os.path.join(args.save_path, f'{case_id}.pickle')
        if os.path.isfile(output_path):
            continue

        if args.verbose:
            print('\n\nRunning {:}/{:} for request:'.format(num_sampled+1, args.sample_size))
            print(request)

        try:

            if args.edit_mode == 'in-place':

                edit_sample_results = editors.apply_edit(
                    request,
                    model,
                    tok,
                    layer = args.layer,
                    hparams = hparams,
                    other_features = other_features,
                    theta = args.theta,
                    verbose = args.verbose,
                )
            elif args.edit_mode in ['prompt', 'context', 'wikipedia']:

                edit_sample_results = editors.apply_attack(
                    request,
                    model,
                    tok,
                    layer = args.layer,
                    hparams = hparams,
                    other_features = other_features,
                    edit_mode = args.edit_mode,
                    theta = args.theta,
                    augmented_cache = args.augmented_cache,
                    verbose = args.verbose,
                )

            # Removing some keys from the result dict
            keys_to_remove = ['w1_weight', 'w1a_weight', 'w1b_weight', 'w1_bias', 'w2_weight', 'w2_bias', 'weights_to_modify']
            for key in keys_to_remove:
                if key in edit_sample_results:
                    edit_sample_results.pop(key, None)

            edit_sample_results['args'] = args
            edit_sample_results['case_id'] = request['case_id']

            utils.savepickle(output_path, edit_sample_results)
            if args.verbose: print('Saved results to:', output_path)

        except Exception as e:
            print('Failed for case_id:', case_id)
            print(e)

        num_sampled += 1
        pbar.update(1)

    pbar.close()


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset for evaluation')

    parser.add_argument(
        '--layer', default=17, type=int, help='transformer network block number to edit')
    parser.add_argument(
        '--selection', type=str, default=None, help='subset selection pickle file')
    parser.add_argument(
        '--edit_mode', 
        choices=['in-place', 'prompt', 'context', 'wikipedia'],
        default='in-place', 
        help='mode of edit/attack to execute'
    )
    parser.add_argument(
        '--static_context', type=str, default=None, help='output directory')
    parser.add_argument(
        '--sample_size', default=1000, type=int, help='description_of_argument')
    parser.add_argument(
        '--to_run', default=None, type=int, help='description_of_argument')

    parser.add_argument(
        '--theta', default=0.005, type=float, help='`bias` for inserted f')
    parser.add_argument(
        '--Delta', default=50.0, type=float, help='magnitude of target response')

    parser.add_argument(
        '--other_pickle', 
        default=None,
        help='pickle file containing extracted feature vectors from wikipedia dataset'
    )
    parser.add_argument(
        '--augmented_cache', type=str, default=None, help='output directory')

    parser.add_argument(
        '--verbose', action="store_true")
    parser.add_argument(
        '--save_path', type=str, default='./results/tmp/', help='results path')

    args = parser.parse_args()

    # construct paths
    if (args.selection is not None) and ('{}' in args.selection):
        args.selection = args.selection.format(args.dataset, args.model)

    if (args.other_pickle is not None) and ('{}' in args.other_pickle):
        args.other_pickle = args.other_pickle.format(args.model, args.layer)

    # ensure results path exists
    args.save_path = os.path.join(args.save_path, f'{args.dataset}/{args.model}/layer{args.layer}/')
    utils.assure_path_exists(args.save_path)

    # run edits
    edit(args)