# evaluation/eval_ppl.py
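"""Perplexity evaluation of edited models.

Loads per-sample edit results (pickle files named by case_id) and measures
perplexity over target, other, and augmented prompts.
"""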
import os
import sys
import argparse

import numpy as np

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from util import utils
from util import perplexity

from pytictoc import TicToc
pyt = TicToc()  # create timer instance

def main_eval(args):

    # load hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # fill dataset/model placeholders in the selection path
    if (args.selection is not None) and ('{}' in args.selection):
        args.selection = args.selection.format(args.dataset, args.model)

    # find results path
    args.save_path = os.path.join(args.save_path, f'{args.dataset}/{args.model}/layer{args.layer}/')

    # create new folder under results path to save new results
    output_dir = os.path.join(args.save_path, 'perplexity/')
    utils.assure_path_exists(output_dir)

    ## LOAD MODEL ######################################################

    # load model and tokenizer
    model, tok = utils.load_model_tok(model_name=args.model)

    # load activation function for MLP components of model
    activation = utils.load_activation(hparams['activation'])
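    # NOTE: `activation` appears unused below; it is presumably loaded for
    # consistency with the editing scripts that share these hyperparameters.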

    # load dataset (in-place edits on mcf use the reversed selection and targets)
    if (args.edit_mode == 'in-place') and (args.dataset == 'mcf'):
        reverse_selection = True
        reverse_target = True
    else:
        reverse_selection = False
        reverse_target = False

    print('Loading dataset:', args.dataset)
    ds, _, _ = utils.load_dataset(
        tok,
        ds_name=args.dataset,
        selection=args.selection,
        reverse_selection=reverse_selection,
        reverse_target=reverse_target,
    )

    # find all requests and case_ids
    dataset_requests = utils.extract_requests(ds)
    case_ids = np.array([r['case_id'] for r in dataset_requests])

    ## LOAD DATA #######################################################

    # find sample files to run (sample files are named by case_id)
    sample_files = np.array([f for f in os.listdir(args.save_path) if f.endswith('.pickle')])
    if args.shuffle:
        sample_files = utils.shuffle_list(sample_files)

    print('Number of pickle files:', len(sample_files))
    print('Running files:', sample_files)

    if len(sample_files) == 0:
        print('No files to run')
        sys.exit()

    ## PROCESSING #######################################################

    # arguments shared by all perplexity measurements
    perplexity_arguments = {
        'token_window': args.token_window,
        'batch_size': args.batch_size,
        'verbose': True,
    }

    # find or generate cache for perplexity measures of other samples
    cache_ppl_file = os.path.join(
        args.cache_path,
        f'inference_ppl_{args.dataset}_{args.model}_tw{args.token_window}.pickle'
    )
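    # with the default arguments this resolves to e.g.
    # ./cache/inference_ppl_mcf_gpt-j-6b_tw50.pickle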
    cache_ppl_contents = perplexity.cache_ppl(
        model,
        tok,
        dataset=args.dataset,
        cache_ppl_file=cache_ppl_file,
        selection=args.selection,
        reverse_selection=reverse_selection,
        **perplexity_arguments
    )
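    # the cached perplexities must align one-to-one with the dataset's case_ids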
    assert np.array_equal(case_ids, cache_ppl_contents['case_ids'])

    if args.eval_oap:
        # separate cache for prompts evaluated under the static context
        # (str.replace returns a new string, so no copy of the path is needed)
        cache_ppl_oap_file = cache_ppl_file.replace('.pickle', '_static_context.pickle')
        cache_ppl_oap_contents = perplexity.cache_ppl(
            model,
            tok,
            dataset=args.dataset,
            cache_ppl_file=cache_ppl_oap_file,
            static_context=args.static_context,
            selection=args.selection,
            reverse_selection=reverse_selection,
            **perplexity_arguments
        )
        assert np.array_equal(case_ids, cache_ppl_oap_contents['case_ids'])
    else:
        cache_ppl_oap_contents = None
        cache_ppl_oap_file = None
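
    # deferred import, presumably to avoid a circular import; note the relative
    # import requires running this file as a module (python -m evaluation.eval_ppl)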
    from . import eval_utils

    evaluator = eval_utils.PerplexityEvaluator(
        model,
        tok,
        layer=args.layer,
        hparams=hparams,
        ds=ds,
        edit_mode=args.edit_mode,
        token_window=args.token_window,
        batch_size=args.batch_size,
        num_other_prompt_eval=args.num_other_prompt_eval,
        num_aug_prompt_eval=args.num_aug_prompt_eval,
        eval_op=args.eval_op,
        eval_oap=args.eval_oap,
        eval_ap=args.eval_ap,
        eval_aug=args.eval_aug,
        op_cache=cache_ppl_contents,
        oap_cache=cache_ppl_oap_contents,
        verbose=True
    )
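
    # per-sample pipeline: load the result pickle, optionally screen it against
    # the success criteria, evaluate the target and other prompts, then save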
    for sample_idx in range(len(sample_files)):
        print('\n\nSample {:}/{:}'.format(sample_idx + 1, len(sample_files)))
        pyt.tic()  # start timer

        try:
            # load result pickle file
            evaluator.load_sample(args.save_path, sample_files[sample_idx])
            if args.exclusion and not evaluator.first_success_criteria():
                continue

            # evaluate target requests
            evaluator.eval_targets(force_recompute=False)
            if args.exclusion and not evaluator.second_success_criteria():
                continue

            # main evaluation
            evaluator.evaluate()

            # save results
            evaluator.save_sample()

            # clear sample
            evaluator.clear_sample()

        except Exception as e:
            print('Failed for', sample_files[sample_idx])
            print(e)

        pyt.toc()  # stop timer

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset for evaluation')
    parser.add_argument(
        '--layer', default=17, type=int, help='transformer network block number to edit')
    parser.add_argument(
        '--selection', type=str, default=None, help='dataset sample selection file (may contain {} placeholders for dataset and model)')
    parser.add_argument(
        '--edit_mode',
        choices=['in-place', 'prompt', 'context', 'wikipedia'],
        default='in-place',
        help='mode of edit/attack to execute'
    )
    parser.add_argument(
        '--static_context', type=str, default=None, help='static context to prepend to prompts')
    parser.add_argument(
        '--cache_path', default='./cache/', type=str, help='path to cache')
    parser.add_argument(
        '--token_window', type=int, default=50, help='token window for perplexity measures')
    parser.add_argument(
        '--batch_size', type=int, default=64, help='batch size for inference')
    parser.add_argument(
        '--shuffle', action="store_true", help='shuffle samples to evaluate')
    parser.add_argument(
        '--eval_op', type=int, default=1, help='eval of other prompts')
    parser.add_argument(
        '--eval_oap', type=int, default=0, help='eval of static context + prompts')
    parser.add_argument(
        '--eval_ap', type=int, default=0, help='eval of attack context + prompts')
    parser.add_argument(
        '--eval_aug', type=int, default=0, help='eval of augmented prompts')
    parser.add_argument(
        '--num_other_prompt_eval', type=int, default=500, help='number of other prompts to evaluate')
    parser.add_argument(
        '--num_aug_prompt_eval', type=int, default=500, help='number of augmented prompts to evaluate')
    parser.add_argument(
        '--exclusion', type=int, default=1, help='skip samples that fail the success criteria')
    parser.add_argument(
        '--save_path', type=str, default='./results/tmp/', help='results path')
    args = parser.parse_args()

    # convert integer flags to booleans (--shuffle is already a store_true bool)
    args.eval_op = bool(args.eval_op)
    args.eval_oap = bool(args.eval_oap)
    args.eval_ap = bool(args.eval_ap)
    args.eval_aug = bool(args.eval_aug)
    args.exclusion = bool(args.exclusion)

    # run main
    main_eval(args)
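
# Example invocation (a sketch; assumes edit-result pickles already exist under
# ./results/tmp/mcf/gpt-j-6b/layer17/ and that the repo is run as a package,
# which the relative import of eval_utils requires):
#   python -m evaluation.eval_ppl --model gpt-j-6b --dataset mcf --layer 17 \
#       --edit_mode in-place --save_path ./results/tmp/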