# stealth-edits/experiments/extract_features.py
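#
# Caches per-layer features of the rewrite module (taken at the final prompt token)
# for every request in a dataset.
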
import os
import copy
import argparse
import numpy as np
from tqdm import tqdm
from util import utils
from util import extraction, evaluation


def cache_features(
    model,
    tok,
    dataset,
    hparams,
    cache_features_file,
    layers,
    batch_size=64,
    static_context='',
    selection=None,
    reverse_selection=False,
    verbose=True,
):
    """ Load features from an existing cache file, or extract them from the dataset and cache them.
    """
    if os.path.exists(cache_features_file):
        print('Loaded cached features file: ', cache_features_file)
        cache_features_contents = utils.loadpickle(cache_features_file)
        raw_case_ids = cache_features_contents['case_ids']
    else:
        # find raw requests and case_ids
        raw_ds, _, _ = utils.load_dataset(tok, ds_name=dataset)
        raw_requests = utils.extract_requests(raw_ds)
        raw_case_ids = np.array([r['case_id'] for r in raw_requests])

        # construct prompts and subjects
        subjects = [static_context + r['prompt'].format(r['subject']) for r in raw_requests]
        prompts = ['{}'] * len(subjects)
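        # identity templates: each full (context + formatted) prompt is passed as the
        # "subject", so '{}'.format(subject) reproduces it during extraction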

        # run multilayer feature extraction
        _returns_across_layer = extraction.extract_multilayer_at_tokens(
            model,
            tok,
            prompts,
            subjects,
            layers=layers,
            module_template=hparams['rewrite_module_tmp'],
            tok_type='prompt_final',
            track='in',
            batch_size=batch_size,
            return_logits=False,
            verbose=verbose,
        )
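
        # keep only the incoming ('in') activations recorded for each hooked module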
        for key in _returns_across_layer:
            _returns_across_layer[key] = _returns_across_layer[key]['in']

        cache_features_contents = {}
        for i in layers:
            cache_features_contents[i] = \
                _returns_across_layer[hparams['rewrite_module_tmp'].format(i)]

        cache_features_contents['case_ids'] = raw_case_ids
        cache_features_contents['prompts'] = np.array(prompts)
        cache_features_contents['subjects'] = np.array(subjects)

        utils.assure_path_exists(os.path.dirname(cache_features_file))
        utils.savepickle(cache_features_file, cache_features_contents)
        print('Saved features cache file: ', cache_features_file)

    # filter cache_features_contents for selected samples
    if selection is not None:

        # load json file containing a dict whose 'case_ids' key lists the selected samples
        select_case_ids = utils.loadjson(selection)['case_ids']

        # boolean mask for selected samples w.r.t. all samples in the subjects pickle
        matching = utils.generate_mask(raw_case_ids, np.array(select_case_ids))
        if reverse_selection:
            matching = ~matching

        # filter cache_features_contents for selected samples
        cache_features_contents = utils.filter_for_selection(cache_features_contents, matching)

    return cache_features_contents
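

# A minimal sketch of calling cache_features() directly; the paths, layer index and
# selection file below are illustrative assumptions, not files shipped with this script:
#
#   model, tok = utils.load_model_tok('gpt-j-6b')
#   hparams = utils.loadjson('./hparams/SE/gpt-j-6b.json')
#   feats = cache_features(
#       model, tok, 'mcf', hparams,
#       cache_features_file='./cache/example_features.pickle',
#       layers=[17],
#       selection='./cache/example_selection.json',  # JSON dict with a 'case_ids' list
#   )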


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset for evaluation')
    parser.add_argument(
        '--batch_size', type=int, default=64, help='batch size for extraction')
    parser.add_argument(
        '--layer', type=int, default=None, help='layer for extraction')
    parser.add_argument(
        '--cache_path', type=str, default='./cache/', help='output directory')
    args = parser.parse_args()

    # loading hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)
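    # hparams is expected to contain at least 'rewrite_module_tmp' and 'model_name'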

    # ensure save path exists
    utils.assure_path_exists(args.cache_path)

    # load model
    model, tok = utils.load_model_tok(args.model)

    # get layers to extract features from
    if args.layer is not None:
        layers = [args.layer]
        cache_features_file = os.path.join(
            args.cache_path, f'prompts_extract_{args.dataset}_{args.model}_layer{args.layer}.pickle'
        )
    else:
        layers = evaluation.model_layer_indices[hparams['model_name']]
        cache_features_file = os.path.join(
            args.cache_path, f'prompts_extract_{args.dataset}_{args.model}.pickle'
        )

    # cache features
    _ = cache_features(
        model,
        tok,
        args.dataset,
        hparams,
        cache_features_file,
        layers,
        batch_size=args.batch_size,
        verbose=True,
    )
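

# Example invocation (illustrative; adjust model, dataset and layer to your setup):
#   python experiments/extract_features.py --model gpt-j-6b --dataset mcf --layer 17 --cache_path ./cache/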