Spaces:
Running
Running
from __future__ import absolute_import, division, print_function | |
import numpy as np | |
import torch | |
from tqdm import tqdm | |
import ot | |
from math import log | |
from collections import defaultdict, Counter | |
from transformers import AutoModelForMaskedLM, AutoTokenizer | |
class BaryScoreMetric:
    """BaryScore metric.

    For every selected hidden layer of a pretrained masked language model, a
    sentence is represented by the unit-normalised embeddings of its tokens.
    A free-support Wasserstein barycenter of these per-layer measures is computed
    for the hypothesis and for the reference. The two barycenters are then compared
    with an exact Wasserstein distance and with entropic-regularised OT costs.
    """

    # Regularisation strengths passed to ``ot.bregman.sinkhorn2``; one
    # ``SD_<reg>`` entry is reported per value.
    SINKHORN_REGS = (10, 1, 5, 0.1, 0.5, 0.01, 0.001)

    def __init__(self, model_name="bert-base-uncased", last_layers=5, use_idfs=True, sinkhorn_ref=0.01):
        """
        BaryScore metric

        :param model_name: model name or path from the HuggingFace library
        :param last_layers: number of final hidden-state layers to use; must be at
            least 1 and smaller than ``num_hidden_layers + 1``
        :param use_idfs: if true weight tokens by corpus idf, else use uniform weights
        :param sinkhorn_ref: weight of the KL in the SD (stored; not read by this class)
        :raises ValueError: if ``last_layers`` is out of range
        """
        self.model_name = model_name
        self.load_tokenizer_and_model()
        # Number of hidden states (+1, presumably for the embedding-layer output).
        n = self.model.config.num_hidden_layers + 1
        if not 0 < last_layers < n:
            raise ValueError('last_layers must be between 1 and {} (got {})'.format(n - 1, last_layers))
        self.layers_to_consider = range(n - last_layers, n)
        self.use_idfs = use_idfs
        self.sinkhorn_ref = sinkhorn_ref
        self.idfs = []
        # (idf_dict_hyp, idf_dict_ref), filled by prepare_idfs.
        self.model_ids = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def prepare_idfs(self, hyps, refs):
        """
        Compute corpus-level idf dictionaries and store them for :meth:`evaluate_batch`.

        :param hyps: hypothesis list of string sentences (whole corpus)
        :param refs: reference list of string sentences (whole corpus)
        :return: ``(idf_dict_hyp, idf_dict_ref)``
        """
        if isinstance(hyps, str):
            hyps = [hyps]
        if isinstance(refs, str):
            refs = [refs]
        idf_dict_hyp = self.ref_list_to_idf(self.tokenizer(hyps)['input_ids'])
        idf_dict_ref = self.ref_list_to_idf(self.tokenizer(refs)['input_ids'])
        # Keep the (hyp, ref) order used by the return value and by evaluate_batch.
        self.model_ids = (idf_dict_hyp, idf_dict_ref)
        return idf_dict_hyp, idf_dict_ref

    def ref_list_to_idf(self, input_refs):
        """
        Build a smoothed idf dictionary from tokenised documents.

        A token id occurring in ``c`` of the ``N`` documents gets
        ``log((N + 1) / (c + 1))``; ids never seen get ``log(N + 1)``.

        :param input_refs: list of token-id lists, one per document
        :return: ``defaultdict`` mapping token id to idf (unseen ids are inserted
            with the default value on first lookup)
        """
        num_docs = len(input_refs)
        idf_count = Counter()
        for token_ids in input_refs:
            # Count each id at most once per document.
            idf_count.update(set(token_ids))
        default_idf = log(num_docs + 1)
        idf_dict = defaultdict(lambda: default_idf)
        idf_dict.update({idx: log((num_docs + 1) / (c + 1)) for idx, c in idf_count.items()})
        return idf_dict

    def load_tokenizer_and_model(self):
        """
        Load the tokenizer and masked-LM model named by ``self.model_name``,
        enable hidden-state outputs and put the model in eval mode.
        """
        tokenizer = AutoTokenizer.from_pretrained(str(self.model_name))
        model = AutoModelForMaskedLM.from_pretrained(str(self.model_name))
        model.config.output_hidden_states = True
        model.eval()
        self.tokenizer = tokenizer
        self.model = model

    def evaluate_batch(self, batch_hyps, batch_refs, idf_hyps=None, idf_ref=None):
        """
        Score each hypothesis against the reference at the same position.

        :param batch_hyps: hypothesis list of string sentences (a single string is accepted)
        :param batch_refs: reference list of string sentences (a single string is accepted)
        :param idf_hyps: idfs of hypothesis computed at corpus level; defaults to the
            dictionary stored by :meth:`prepare_idfs` (only used when ``use_idfs``)
        :param idf_ref: idfs of references computed at corpus level; defaults to the
            dictionary stored by :meth:`prepare_idfs` (only used when ``use_idfs``)
        :return: dictionary mapping ``baryscore_<measure>`` to the list of
            per-sentence scores; an empty dictionary for an empty batch
        :raises ValueError: if the batches differ in length, or if ``use_idfs`` is
            set and no idf dictionaries are available
        """
        if isinstance(batch_hyps, str):
            batch_hyps = [batch_hyps]
        if isinstance(batch_refs, str):
            batch_refs = [batch_refs]
        if len(batch_hyps) != len(batch_refs):
            raise ValueError('batch_hyps and batch_refs must have the same length ({} != {})'.format(
                len(batch_hyps), len(batch_refs)))
        if not batch_refs:
            return {}

        if self.use_idfs:
            idf_hyps, idf_ref = self._resolve_idfs(idf_hyps, idf_ref)

        model = self.model.to(self.device)
        with torch.no_grad():
            ref_embeddings, ref_token_ids = self._embed(model, batch_refs)
            hyp_embeddings, hyp_token_ids = self._embed(model, batch_hyps)

            pad_id = self.tokenizer.pad_token_id
            baryscores = []
            for index_sentence in tqdm(range(len(batch_refs)), 'BaryScore Progress'):
                ref_ids = ref_token_ids[index_sentence]
                hyp_ids = hyp_token_ids[index_sentence]
                # Drop padding so each sentence is a measure over its own tokens only
                # (works for left or right padding).
                ref_positions = self._content_positions(ref_ids, pad_id)
                hyp_positions = self._content_positions(hyp_ids, pad_id)
                # One (n_tokens, hidden) array per selected layer.
                measures_locations_ref = self._layer_measures(
                    ref_embeddings[:, index_sentence, ref_positions, :])
                measures_locations_hyps = self._layer_measures(
                    hyp_embeddings[:, index_sentence, hyp_positions, :])
                if self.use_idfs:
                    weights_ref = self._lookup_idfs(idf_ref, [ref_ids[k] for k in ref_positions])
                    weights_hyp = self._lookup_idfs(idf_hyps, [hyp_ids[k] for k in hyp_positions])
                else:
                    weights_ref = weights_hyp = None
                scores = self.baryscore(measures_locations_ref, measures_locations_hyps,
                                        weights_ref, weights_hyp)
                baryscores.append({'baryscore_{}'.format(key): value for key, value in scores.items()})
        return {key: [score[key] for score in baryscores] for key in baryscores[0]}

    def _resolve_idfs(self, idf_hyps, idf_ref):
        """Fill any missing idf dictionary from the ones stored by :meth:`prepare_idfs`."""
        if idf_hyps is None or idf_ref is None:
            stored = getattr(self, 'model_ids', None)
            if stored is None:
                raise ValueError('use_idfs=True requires idf dictionaries: call prepare_idfs(hyps, refs) '
                                 'first or pass idf_hyps and idf_ref')
            stored_hyps, stored_ref = stored
            if idf_hyps is None:
                idf_hyps = stored_hyps
            if idf_ref is None:
                idf_ref = stored_ref
        return idf_hyps, idf_ref

    def _embed(self, model, sentences):
        """
        Tokenise ``sentences``, run ``model`` and return the selected layers.

        :param model: the masked-LM model, already on ``self.device``
        :param sentences: list of strings
        :return: ``(embeddings, token_ids)`` where ``embeddings`` stacks the
            selected hidden states on a new leading dim and is unit-normalised
            along the last dim, and ``token_ids`` holds one padded list of python
            ints per sentence
        """
        encoded = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(self.device)
        # With output_hidden_states enabled, the last output element is presumably
        # the tuple of hidden states -- TODO confirm if attentions are ever enabled.
        hidden_states = model(**encoded)[-1]
        embeddings = torch.stack([hidden_states[i] for i in self.layers_to_consider])
        embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True)
        return embeddings, encoded['input_ids'].cpu().tolist()

    @staticmethod
    def _content_positions(token_ids, pad_id):
        """Return the indices of the non-padding entries of ``token_ids``."""
        return [k for k, token_id in enumerate(token_ids) if token_id != pad_id]

    @staticmethod
    def _layer_measures(embedding):
        """Split a (n_layers, n_tokens, dim) tensor into per-layer float64 numpy arrays."""
        return list(embedding.detach().cpu().double().numpy())

    @staticmethod
    def _lookup_idfs(idf_dict, token_ids):
        """
        Look up the idf weight of each token id.

        Ids are cast to python ``int`` first. ``torch.Tensor`` hashes by object
        identity, so looking up a tensor element in an int-keyed dict always
        misses: it returns the default idf and inserts a new tensor key.
        """
        return [idf_dict[int(token_id)] for token_id in token_ids]

    @staticmethod
    def _normalize_weights(weights):
        """
        Return ``weights`` as a float64 vector summing to one.

        If the weights do not sum to a positive value, uniform weights are returned
        instead of NaNs. For idfs this happens when every token of the sentence
        occurs in every document of the corpus (idf ``log(1) == 0``).
        """
        weights = np.asarray(weights, dtype=np.float64)
        total = weights.sum()
        if not total > 0:
            return np.full(weights.shape, 1.0 / weights.size)
        return weights / total

    def baryscore(self, measures_locations_ref, measures_locations_hyps, weights_refs, weights_hyps):
        """
        Compute BaryScore between one reference and one hypothesis.

        Each measure list holds one ``(n_tokens, dim)`` location array per layer.
        A free-support barycenter across layers is built for each side, and the
        two barycenters are compared.

        :param measures_locations_ref: per-layer reference token locations
        :param measures_locations_hyps: per-layer hypothesis token locations
        :param weights_refs: reference token weights (e.g. idfs) or None for uniform
        :param weights_hyps: hypothesis token weights or None for uniform
        :return: dict with ``"W"`` (``ot.emd2`` cost between the barycenters) and one
            ``"SD_<reg>"`` entry (``ot.bregman.sinkhorn2`` cost) per value of
            ``SINKHORN_REGS``
        :raises ValueError: if only one of the two weight vectors is given
        """
        if weights_hyps is not None or weights_refs is not None:
            if weights_refs is None or weights_hyps is None:
                raise ValueError('weights_refs and weights_hyps must be given together')
            weights_hyps = self._normalize_weights(weights_hyps)
            weights_refs = self._normalize_weights(weights_refs)
        # Shared with w_barycenter through the instance.
        self.n_layers = len(measures_locations_ref)
        self.d_bert = measures_locations_ref[0].shape[1]
        ####################################
        ## Compute Wasserstein Barycenter ##
        ####################################
        bary_ref = self.w_barycenter(measures_locations_ref, weights_refs)
        bary_hyp = self.w_barycenter(measures_locations_hyps, weights_hyps)
        #################################################
        ## Compute Wasserstein and Sinkhorn Divergence ##
        #################################################
        C = ot.dist(bary_ref, bary_hyp)
        weights_first_barycenter = np.full(C.shape[0], 1 / C.shape[0])
        weights_second_barycenter = np.full(C.shape[1], 1 / C.shape[1])
        wasserstein_distance = ot.emd2(weights_first_barycenter, weights_second_barycenter, C,
                                       log=True)[0]
        dic_results = {
            "W": wasserstein_distance,
        }
        for reg in self.SINKHORN_REGS:
            wasserstein_sinkhorn = np.asarray(
                ot.bregman.sinkhorn2(weights_first_barycenter, weights_second_barycenter, C,
                                     reg=reg, numItermax=10000)).tolist()
            if isinstance(wasserstein_sinkhorn, list):
                wasserstein_sinkhorn = wasserstein_sinkhorn[0]  # for POT==0.7.0
            dic_results['SD_{}'.format(reg)] = wasserstein_sinkhorn
        return dic_results

    def w_barycenter(self, measures_locations, weights):
        """
        Free-support Wasserstein barycenter of the per-layer measures.

        Reads ``self.n_layers`` and ``self.d_bert`` set by :meth:`baryscore`.

        :param measures_locations: location of the discrete input measures, one per layer
        :param weights: token weights shared by all layers, or None for uniform
        :return: barycenter support locations, one row per token of the first measure
        """
        n_points = measures_locations[0].shape[0]
        X_init = np.zeros((n_points, self.d_bert), dtype=np.float64)
        if weights is None:
            layer_weights = np.full(n_points, 1 / n_points)
        else:
            layer_weights = self._normalize_weights(weights)
        measures_weights = [layer_weights] * self.n_layers
        b = np.full(n_points, 1 / n_points, dtype=np.float64)
        return ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init,
                                             b=b, numItermax=1000, verbose=False)

    def supports_multi_ref(self):
        """
        :return: BaryScore does not support multi ref
        """
        return False
if __name__ == '__main__':
    # Example usage of BaryScore.
    metric_call = BaryScoreMetric(use_idfs=False)
    ref = [
        'I like my cakes very much',
        'I hate these cakes!']
    hypothesis = ['I like my cakes very much',
                  'I like my cakes very much']
    # Both methods take the hypotheses first and the references second.
    metric_call.prepare_idfs(hypothesis, ref)
    final_preds = metric_call.evaluate_batch(hypothesis, ref)
    print(final_preds)