# TODO consider if this can be collapsed back down into the pipeline_train.py
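#
# Example invocation (file and directory names are hypothetical placeholders):
#   python bert_pipeline.py --data_dir data/movies --output_dir output/movies \
#       --model_params params/movies.json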
import argparse
import json
import logging
import random
import os
from collections import OrderedDict
from itertools import chain
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score
from transformers import BertTokenizer, BertForSequenceClassification

from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from BERT_rationale_benchmark.utils import (
    Annotation,
    Evidence,
    write_jsonl,
    load_datasets,
    load_documents,
)
from BERT_explainability.modules.BERT.BertForSequenceClassification import \
    BertForSequenceClassification as BertForSequenceClassificationTest
from BERT_explainability.modules.BERT.BERT_cls_lrp import \
    BertForSequenceClassification as BertForClsOrigLrp

logging.basicConfig(level=logging.DEBUG, format='%(relativeCreated)6d %(threadName)s %(message)s')
logger = logging.getLogger(__name__)
# let's make this more or less deterministic (not resistant to restarts)
random.seed(12345)
np.random.seed(67890)
torch.manual_seed(10111213)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
latex_special_token = ["!@#$%^&*()"]
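

# `rescale` is called by generate() when rescale_value=True but was never
# defined in this file; a minimal min-max sketch, assuming scores should be
# mapped onto the 0-100 range that a LaTeX \colorbox percentage expects.
def rescale(input_list):
    the_array = np.asarray(input_list)
    the_max = the_array.max()
    the_min = the_array.min()
    if the_max == the_min:
        # constant scores carry no signal; render everything white
        return [0.0] * len(input_list)
    rescaled = (the_array - the_min) / (the_max - the_min) * 100
    return rescaled.tolist()
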
def generate(text_list, attention_list, latex_file, color='red', rescale_value=False):
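    r"""Render text_list as a LaTeX heatmap with one \colorbox per token.

    Scores are min-max scaled to [0, 100] (the percentage \colorbox expects);
    values below 1 are zeroed, and a constant score map is rendered all-white.
    '##' wordpieces are glued back onto the preceding token without a space.
    """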
    attention_list = attention_list[:len(text_list)]
    if attention_list.max() == attention_list.min():
        attention_list = torch.zeros_like(attention_list)
    else:
        attention_list = 100 * (attention_list - attention_list.min()) / (attention_list.max() - attention_list.min())
    attention_list[attention_list < 1] = 0
    attention_list = attention_list.tolist()
    text_list = [token.replace('$', '') for token in text_list]
    if rescale_value:
        attention_list = rescale(attention_list)
    word_num = len(text_list)
    text_list = clean_word(text_list)
    with open(latex_file, 'w') as f:
        f.write(r'''\documentclass[varwidth=150mm]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}''' + '\n')
        string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{''' + "\n"
        for idx in range(word_num):
            if r'\#\#' in text_list[idx]:
                # clean_word escaped '##' to '\#\#'; glue the wordpiece back
                # onto the previous token without a space
                token = text_list[idx].replace(r'\#\#', '')
                string += "\\colorbox{%s!%s}{" % (color, attention_list[idx]) + "\\strut " + token + "}"
            else:
                string += " " + "\\colorbox{%s!%s}{" % (color, attention_list[idx]) + "\\strut " + text_list[idx] + "}"
        string += "\n}}}"
        f.write(string + '\n')
        f.write(r'''\end{CJK*}
\end{document}''')


def clean_word(word_list):
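    """Backslash-escape LaTeX-sensitive characters in each word."""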
    new_word_list = []
    for word in word_list:
        for latex_sensitive in ["\\", "%", "&", "^", "#", "_", "{", "}"]:
            if latex_sensitive in word:
                word = word.replace(latex_sensitive, '\\' + latex_sensitive)
        new_word_list.append(word)
    return new_word_list


def scores_per_word_from_scores_per_token(input, tokenizer, input_ids, scores_per_id):
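    """Project wordpiece-level scores onto whitespace-delimited words.

    Each token's score is replicated over its characters (special tokens
    skipped); each word in `input` then takes the max score over its character
    span. The words_from_chars bookkeeping (marked TODO: DELETE) only
    sanity-checks the alignment.
    """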
    words = tokenizer.convert_ids_to_tokens(input_ids)
    words = [word.replace('##', '') for word in words]
    score_per_char = []
    # TODO: DELETE
    input_ids_chars = []
    for word in words:
        if word in ['[CLS]', '[SEP]', '[UNK]', '[PAD]']:
            continue
        input_ids_chars += list(word)
    # TODO: DELETE
    for i in range(len(scores_per_id)):
        if words[i] in ['[CLS]', '[SEP]', '[UNK]', '[PAD]']:
            continue
        score_per_char += [scores_per_id[i]] * len(words[i])
    score_per_word = []
    start_idx = 0
    end_idx = 0
    # TODO: DELETE
    words_from_chars = []
    for inp in input:
        if start_idx >= len(score_per_char):
            break
        end_idx = end_idx + len(inp)
        score_per_word.append(np.max(score_per_char[start_idx:end_idx]))
        # TODO: DELETE
        words_from_chars.append(''.join(input_ids_chars[start_idx:end_idx]))
        start_idx = end_idx
    if words_from_chars[:-1] != input[:len(words_from_chars) - 1]:
        print(words_from_chars)
        print(input[:len(words_from_chars)])
        print(words)
        print(tokenizer.convert_ids_to_tokens(input_ids))
        assert False
    return torch.tensor(score_per_word)


def get_input_words(input, tokenizer, input_ids):
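    """Return the words of `input` that survive wordpiece tokenization.

    Reconstructs words from token characters (special tokens skipped) so the
    document can be cropped to what the model actually saw, and sanity-checks
    the reconstruction against `input`.
    """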
    words = tokenizer.convert_ids_to_tokens(input_ids)
    words = [word.replace('##', '') for word in words]
    input_ids_chars = []
    for word in words:
        if word in ['[CLS]', '[SEP]', '[UNK]', '[PAD]']:
            continue
        input_ids_chars += list(word)
    start_idx = 0
    end_idx = 0
    words_from_chars = []
    for inp in input:
        if start_idx >= len(input_ids_chars):
            break
        end_idx = end_idx + len(inp)
        words_from_chars.append(''.join(input_ids_chars[start_idx:end_idx]))
        start_idx = end_idx
    if words_from_chars[:-1] != input[:len(words_from_chars) - 1]:
        print(words_from_chars)
        print(input[:len(words_from_chars)])
        print(words)
        print(tokenizer.convert_ids_to_tokens(input_ids))
        assert False
    return words_from_chars


def bert_tokenize_doc(doc: List[List[str]], tokenizer, special_token_map) -> Tuple[List[List[str]], List[List[Tuple[int, int]]]]:
    """Tokenizes a document and returns [start, end) spans to map the wordpieces back to their source words."""
    sents = []
    sent_token_spans = []
    for sent in doc:
        tokens = []
        spans = []
        start = 0
        for w in sent:
            if w in special_token_map:
                tokens.append(w)
            else:
                tokens.extend(tokenizer.tokenize(w))
            end = len(tokens)
            spans.append((start, end))
            start = end
        sents.append(tokens)
        sent_token_spans.append(spans)
    return sents, sent_token_spans


def initialize_models(params: dict, batch_first: bool, use_half_precision=False):
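    """Build the BERT tokenizer and evidence classifier from `params`.

    Reads 'bert_vocab', 'bert_dir', 'max_length', and
    'evidence_classifier' -> 'classes'; returns the classifier plus the
    tokenizer's interning maps.
    """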
    assert batch_first
    max_length = params['max_length']
    tokenizer = BertTokenizer.from_pretrained(params['bert_vocab'])
    pad_token_id = tokenizer.pad_token_id
    cls_token_id = tokenizer.cls_token_id
    sep_token_id = tokenizer.sep_token_id
    bert_dir = params['bert_dir']
    evidence_classes = dict((y, x) for (x, y) in enumerate(params['evidence_classifier']['classes']))
    evidence_classifier = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=len(evidence_classes))
    word_interner = tokenizer.vocab
    de_interner = tokenizer.ids_to_tokens
    return evidence_classifier, word_interner, de_interner, evidence_classes, tokenizer


BATCH_FIRST = True


def extract_docid_from_dataset_element(element):
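    """Return the docid of the first evidence in the element's first evidence group."""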
    return next(iter(element.evidences))[0].docid


def extract_evidence_from_dataset_element(element):
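    """Return the element's first evidence group (an iterable of Evidence objects)."""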
    return next(iter(element.evidences))


def main():
    parser = argparse.ArgumentParser(description="""Trains a pipeline model.

    Step 1 is evidence identification: decide whether a given sentence is evidence or not.
    Step 2 is evidence classification: given an evidence sentence, classify the final outcome for the task
    (e.g. sentiment or significance).

    These should be two separate steps, but at the moment:
    * prep data (load, intern documents, load json)
    * convert data for evidence identification - in the case of training data we take all the positives and sample some
      negatives
        * side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a
          broader sampling of negative values.
    * train evidence identification
    * convert data for evidence classification - take all rationales + decisions and use this as input
    * train evidence classification
    * decode first the evidence, then run classification for each split
    """, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--data_dir', dest='data_dir', required=True,
                        help='Which directory contains a {train,val,test}.jsonl file?')
    parser.add_argument('--output_dir', dest='output_dir', required=True,
                        help='Where shall we write intermediate models + final data to?')
    parser.add_argument('--model_params', dest='model_params', required=True,
                        help='JSON file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.)')
    args = parser.parse_args()
    assert BATCH_FIRST
    os.makedirs(args.output_dir, exist_ok=True)

    with open(args.model_params, 'r') as fp:
        logger.info(f'Loading model parameters from {args.model_params}')
        model_params = json.load(fp)
        logger.info(f'Params: {json.dumps(model_params, indent=2, sort_keys=True)}')
    train, val, test = load_datasets(args.data_dir)
    docids = set(e.docid for e in
                 chain.from_iterable(chain.from_iterable(map(lambda ann: ann.evidences, chain(train, val, test)))))
    documents = load_documents(args.data_dir, docids)
    logger.info(f'Loaded {len(documents)} documents')
    evidence_classifier, word_interner, de_interner, evidence_classes, tokenizer = \
        initialize_models(model_params, batch_first=BATCH_FIRST)
    logger.info(f'We have {len(word_interner)} wordpieces')

    cache = os.path.join(args.output_dir, 'preprocessed.pkl')
    if os.path.exists(cache):
        logger.info(f'Loading interned documents from {cache}')
        interned_documents = torch.load(cache)
    else:
        logger.info('Interning documents')
        interned_documents = {}
        for d, doc in documents.items():
            encoding = tokenizer.encode_plus(
                doc,
                add_special_tokens=True,
                max_length=model_params['max_length'],
                return_token_type_ids=False,
                pad_to_max_length=False,
                return_attention_mask=True,
                return_tensors='pt',
                truncation=True,
            )
            interned_documents[d] = encoding
        torch.save(interned_documents, cache)

    evidence_classifier = evidence_classifier.cuda()
    optimizer = None
    scheduler = None

    save_dir = args.output_dir
    logging.info('Beginning training classifier')
    evidence_classifier_output_dir = os.path.join(save_dir, 'classifier')
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(evidence_classifier_output_dir, exist_ok=True)
    model_save_file = os.path.join(evidence_classifier_output_dir, 'classifier.pt')
    epoch_save_file = os.path.join(evidence_classifier_output_dir, 'classifier_epoch_data.pt')
    device = next(evidence_classifier.parameters()).device
    if optimizer is None:
        optimizer = torch.optim.Adam(evidence_classifier.parameters(), lr=model_params['evidence_classifier']['lr'])
    criterion = nn.CrossEntropyLoss(reduction='none')
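    # reduction='none' keeps per-example losses; each batch sums them and the
    # epoch totals are divided by the dataset size below, giving a per-example
    # mean loss regardless of batch size.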
    batch_size = model_params['evidence_classifier']['batch_size']
    epochs = model_params['evidence_classifier']['epochs']
    patience = model_params['evidence_classifier']['patience']
    max_grad_norm = model_params['evidence_classifier'].get('max_grad_norm', None)
    class_labels = [k for k, v in sorted(evidence_classes.items())]

    results = {
        'train_loss': [],
        'train_f1': [],
        'train_acc': [],
        'val_loss': [],
        'val_f1': [],
        'val_acc': [],
    }
    best_epoch = -1
    best_val_acc = 0
    best_val_loss = float('inf')
    best_model_state_dict = None
    start_epoch = 0
    epoch_data = {}
    if os.path.exists(epoch_save_file):
        logging.info(f'Restoring model from {model_save_file}')
        evidence_classifier.load_state_dict(torch.load(model_save_file))
        epoch_data = torch.load(epoch_save_file)
        start_epoch = epoch_data['epoch'] + 1
        # handle finishing because patience was exceeded or we didn't get the best final epoch
        if bool(epoch_data.get('done', 0)):
            start_epoch = epochs
        results = epoch_data['results']
        best_epoch = start_epoch
        best_model_state_dict = OrderedDict({k: v.cpu() for k, v in evidence_classifier.state_dict().items()})
        logging.info(f'Restoring training from epoch {start_epoch}')
    logging.info(f'Training evidence classifier from epoch {start_epoch} until epoch {epochs}')
    optimizer.zero_grad()
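    # random.sample with k=len(train) is a full reshuffle of the training data,
    # re-drawn every epoch; the optimizer steps once per batch.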
    for epoch in range(start_epoch, epochs):
        epoch_train_data = random.sample(train, k=len(train))
        epoch_train_loss = 0
        epoch_training_acc = 0
        evidence_classifier.train()
        logging.info(
            f'Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples')
        for batch_start in range(0, len(epoch_train_data), batch_size):
            batch_elements = epoch_train_data[batch_start:min(batch_start + batch_size, len(epoch_train_data))]
            targets = [evidence_classes[s.classification] for s in batch_elements]
            targets = torch.tensor(targets, dtype=torch.long, device=device)
            samples_encoding = [interned_documents[extract_docid_from_dataset_element(s)] for s in batch_elements]
            input_ids = torch.stack([samples_encoding[i]['input_ids'] for i in range(len(samples_encoding))]).squeeze(
                1).to(device)
            attention_masks = torch.stack(
                [samples_encoding[i]['attention_mask'] for i in range(len(samples_encoding))]).squeeze(1).to(device)
            preds = evidence_classifier(input_ids=input_ids, attention_mask=attention_masks)[0]
            epoch_training_acc += accuracy_score(preds.argmax(dim=1).cpu(), targets.cpu(), normalize=False)
            loss = criterion(preds, targets.to(device=preds.device)).sum()
            assert loss == loss  # guard against NaNs
            epoch_train_loss += loss.item()
            loss.backward()
            if max_grad_norm:
                torch.nn.utils.clip_grad_norm_(evidence_classifier.parameters(), max_grad_norm)
            optimizer.step()
            if scheduler:
                scheduler.step()
            optimizer.zero_grad()
        epoch_train_loss /= len(epoch_train_data)
        epoch_training_acc /= len(epoch_train_data)
        assert epoch_train_loss == epoch_train_loss  # guard against NaNs
        results['train_loss'].append(epoch_train_loss)
        results['train_acc'].append(epoch_training_acc)
        logging.info(f'Epoch {epoch} training loss {epoch_train_loss}')
        logging.info(f'Epoch {epoch} training accuracy {epoch_training_acc}')

        with torch.no_grad():
            epoch_val_loss = 0
            epoch_val_acc = 0
            epoch_val_data = random.sample(val, k=len(val))
            evidence_classifier.eval()
            val_batch_size = 32
            logging.info(
                f'Validating with {len(epoch_val_data) // val_batch_size} batches with {len(epoch_val_data)} examples')
            for batch_start in range(0, len(epoch_val_data), val_batch_size):
                batch_elements = epoch_val_data[batch_start:min(batch_start + val_batch_size, len(epoch_val_data))]
                targets = [evidence_classes[s.classification] for s in batch_elements]
                targets = torch.tensor(targets, dtype=torch.long, device=device)
                samples_encoding = [interned_documents[extract_docid_from_dataset_element(s)] for s in batch_elements]
                input_ids = torch.stack(
                    [samples_encoding[i]['input_ids'] for i in range(len(samples_encoding))]).squeeze(1).to(device)
                attention_masks = torch.stack(
                    [samples_encoding[i]['attention_mask'] for i in range(len(samples_encoding))]).squeeze(1).to(
                    device)
                preds = evidence_classifier(input_ids=input_ids, attention_mask=attention_masks)[0]
                epoch_val_acc += accuracy_score(preds.argmax(dim=1).cpu(), targets.cpu(), normalize=False)
                loss = criterion(preds, targets.to(device=preds.device)).sum()
                epoch_val_loss += loss.item()
            epoch_val_loss /= len(val)
            epoch_val_acc /= len(val)
            results["val_acc"].append(epoch_val_acc)
            results["val_loss"].append(epoch_val_loss)
            logging.info(f'Epoch {epoch} val loss {epoch_val_loss}')
            logging.info(f'Epoch {epoch} val acc {epoch_val_acc}')

            if epoch_val_acc > best_val_acc or (epoch_val_acc == best_val_acc and epoch_val_loss < best_val_loss):
                best_model_state_dict = OrderedDict({k: v.cpu() for k, v in evidence_classifier.state_dict().items()})
                best_epoch = epoch
                best_val_acc = epoch_val_acc
                best_val_loss = epoch_val_loss
                epoch_data = {
                    'epoch': epoch,
                    'results': results,
                    'best_val_acc': best_val_acc,
                    'done': 0,
                }
                torch.save(evidence_classifier.state_dict(), model_save_file)
                torch.save(epoch_data, epoch_save_file)
                logging.debug(f'Epoch {epoch} new best model with val accuracy {epoch_val_acc}')
        if epoch - best_epoch > patience:
            logging.info(f'Exiting after epoch {epoch} due to no improvement')
            epoch_data['done'] = 1
            torch.save(epoch_data, epoch_save_file)
            break
    epoch_data['done'] = 1
    epoch_data['results'] = results
    torch.save(epoch_data, epoch_save_file)
    evidence_classifier.load_state_dict(best_model_state_dict)
    evidence_classifier = evidence_classifier.to(device=device)
    evidence_classifier.eval()

    # test
    test_classifier = BertForSequenceClassificationTest.from_pretrained(model_params['bert_dir'],
                                                                        num_labels=len(evidence_classes)).to(device)
    orig_lrp_classifier = BertForClsOrigLrp.from_pretrained(model_params['bert_dir'],
                                                            num_labels=len(evidence_classes)).to(device)
    if os.path.exists(epoch_save_file):
        logging.info(f'Restoring model from {model_save_file}')
        test_classifier.load_state_dict(torch.load(model_save_file))
        orig_lrp_classifier.load_state_dict(torch.load(model_save_file))
        test_classifier.eval()
        orig_lrp_classifier.eval()
        test_batch_size = 1
        logging.info(
            f'Testing with {len(test) // test_batch_size} batches with {len(test)} examples')

        # explainability
        explanations = Generator(test_classifier)
        explanations_orig_lrp = Generator(orig_lrp_classifier)
        method = "transformer_attribution"
        method_folder = {"transformer_attribution": "ours", "partial_lrp": "partial_lrp", "last_attn": "last_attn",
                         "attn_gradcam": "attn_gradcam", "lrp": "lrp", "rollout": "rollout",
                         "ground_truth": "ground_truth", "generate_all": "generate_all"}
        method_expl = {"transformer_attribution": explanations.generate_LRP,
                       "partial_lrp": explanations_orig_lrp.generate_LRP_last_layer,
                       "last_attn": explanations_orig_lrp.generate_attn_last_layer,
                       "attn_gradcam": explanations_orig_lrp.generate_attn_gradcam,
                       "lrp": explanations_orig_lrp.generate_full_lrp,
                       "rollout": explanations_orig_lrp.generate_rollout}
        os.makedirs(os.path.join(args.output_dir, method_folder[method]), exist_ok=True)
        result_files = []
        for i in range(5, 85, 5):
            result_files.append(
                open(os.path.join(args.output_dir, '{0}/identifier_results_{1}.json').format(method_folder[method], i),
                     'w'))
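        # One output file per top-k budget (k = 5, 10, ..., 80 tokens), each line
        # holding the hard rationale predictions for one test document.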
        j = 0
        for batch_start in range(0, len(test), test_batch_size):
            batch_elements = test[batch_start:min(batch_start + test_batch_size, len(test))]
            targets = [evidence_classes[s.classification] for s in batch_elements]
            targets = torch.tensor(targets, dtype=torch.long, device=device)
            samples_encoding = [interned_documents[extract_docid_from_dataset_element(s)] for s in batch_elements]
            input_ids = torch.stack(
                [samples_encoding[i]['input_ids'] for i in range(len(samples_encoding))]).squeeze(1).to(device)
            attention_masks = torch.stack(
                [samples_encoding[i]['attention_mask'] for i in range(len(samples_encoding))]).squeeze(1).to(
                device)
            preds = test_classifier(input_ids=input_ids, attention_mask=attention_masks)[0]
            for s in batch_elements:
                doc_name = extract_docid_from_dataset_element(s)
                inp = documents[doc_name].split()
                classification = "neg" if targets.item() == 0 else "pos"
                is_classification_correct = 1 if preds.argmax(dim=1) == targets else 0
                if method == "generate_all":
                    file_name = "{0}_{1}_{2}.tex".format(j, classification, is_classification_correct)
                    GT_global = os.path.join(args.output_dir, '{0}/visual_results_{1}.pdf').format(
                        method_folder["ground_truth"], j)
                    GT_ours = os.path.join(args.output_dir, '{0}/{1}_GT_{2}_{3}.pdf').format(
                        method_folder["transformer_attribution"], j, classification, is_classification_correct)
                    CF_ours = os.path.join(args.output_dir, '{0}/{1}_CF.pdf').format(
                        method_folder["transformer_attribution"], j)
                    GT_partial = os.path.join(args.output_dir, '{0}/{1}_GT_{2}_{3}.pdf').format(
                        method_folder["partial_lrp"], j, classification, is_classification_correct)
                    CF_partial = os.path.join(args.output_dir, '{0}/{1}_CF.pdf').format(
                        method_folder["partial_lrp"], j)
                    GT_gradcam = os.path.join(args.output_dir, '{0}/{1}_GT_{2}_{3}.pdf').format(
                        method_folder["attn_gradcam"], j, classification, is_classification_correct)
                    CF_gradcam = os.path.join(args.output_dir, '{0}/{1}_CF.pdf').format(
                        method_folder["attn_gradcam"], j)
                    GT_lrp = os.path.join(args.output_dir, '{0}/{1}_GT_{2}_{3}.pdf').format(
                        method_folder["lrp"], j, classification, is_classification_correct)
                    CF_lrp = os.path.join(args.output_dir, '{0}/{1}_CF.pdf').format(
                        method_folder["lrp"], j)
                    GT_lastattn = os.path.join(args.output_dir, '{0}/{1}_GT_{2}_{3}.pdf').format(
                        method_folder["last_attn"], j, classification, is_classification_correct)
                    GT_rollout = os.path.join(args.output_dir, '{0}/{1}_GT_{2}_{3}.pdf').format(
                        method_folder["rollout"], j, classification, is_classification_correct)
                    with open(file_name, 'w') as f:
                        f.write(r'''\documentclass[varwidth]{standalone}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{graphicx}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}
{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{
\setlength{\tabcolsep}{2pt} % Default value: 6pt
\begin{tabular}{ccc}
\includegraphics[width=0.32\linewidth]{''' + GT_global + r'''}&
\includegraphics[width=0.32\linewidth]{''' + GT_ours + r'''}&
\includegraphics[width=0.32\linewidth]{''' + CF_ours + r'''}\\
(a) & (b) & (c)\\
\includegraphics[width=0.32\linewidth]{''' + GT_partial + r'''}&
\includegraphics[width=0.32\linewidth]{''' + CF_partial + r'''}&
\includegraphics[width=0.32\linewidth]{''' + GT_gradcam + r'''}\\
(d) & (e) & (f)\\
\includegraphics[width=0.32\linewidth]{''' + CF_gradcam + r'''}&
\includegraphics[width=0.32\linewidth]{''' + GT_lrp + r'''}&
\includegraphics[width=0.32\linewidth]{''' + CF_lrp + r'''}\\
(g) & (h) & (i)\\
\includegraphics[width=0.32\linewidth]{''' + GT_lastattn + r'''}&
\includegraphics[width=0.32\linewidth]{''' + GT_rollout + r'''}&\\
(j) & (k) &\\
\end{tabular}
}}}
\end{CJK*}
\end{document}''')
                    j += 1
                    break
                if method == "ground_truth":
                    inp_cropped = get_input_words(inp, tokenizer, input_ids[0])
                    cam = torch.zeros(len(inp_cropped))
                    for evidence in extract_evidence_from_dataset_element(s):
                        start_idx = evidence.start_token
                        if start_idx >= len(cam):
                            break
                        end_idx = evidence.end_token
                        cam[start_idx:end_idx] = 1
                    generate(inp_cropped, cam,
                             os.path.join(args.output_dir, '{0}/visual_results_{1}.tex').format(method_folder[method],
                                                                                                j),
                             color="green")
                    j = j + 1
                    break
                text = tokenizer.convert_ids_to_tokens(input_ids[0])
                classification = "neg" if targets.item() == 0 else "pos"
                is_classification_correct = 1 if preds.argmax(dim=1) == targets else 0
                target_idx = targets.item()
                cam_target = method_expl[method](input_ids=input_ids, attention_mask=attention_masks,
                                                 index=target_idx)[0]
                cam_target = cam_target.clamp(min=0)
                generate(text, cam_target,
                         os.path.join(args.output_dir, '{0}/{1}_GT_{2}_{3}.tex').format(
                             method_folder[method], j, classification, is_classification_correct))
                if method in ["transformer_attribution", "partial_lrp", "attn_gradcam", "lrp"]:
                    cam_false_class = method_expl[method](input_ids=input_ids, attention_mask=attention_masks,
                                                          index=1 - target_idx)[0]
                    cam_false_class = cam_false_class.clamp(min=0)
                    generate(text, cam_false_class,
                             os.path.join(args.output_dir, '{0}/{1}_CF.tex').format(method_folder[method], j))
                cam = cam_target
                cam = scores_per_word_from_scores_per_token(inp, tokenizer, input_ids[0], cam)
                j = j + 1

                doc_name = extract_docid_from_dataset_element(s)
                hard_rationales = []
                for res, i in enumerate(range(5, 85, 5)):
                    print("calculating top", i)
                    _, indices = cam.topk(k=i)
                    for index in indices.tolist():
                        hard_rationales.append({
                            "start_token": index,
                            "end_token": index + 1
                        })
                    result_dict = {
                        "annotation_id": doc_name,
                        "rationales": [{
                            "docid": doc_name,
                            "hard_rationale_predictions": hard_rationales
                        }],
                    }
                    result_files[res].write(json.dumps(result_dict) + "\n")
        for result_file in result_files:
            result_file.close()


if __name__ == '__main__':
    main()