from __future__ import absolute_import, division, print_function

import numpy as np
import torch
from tqdm import tqdm
import ot
from math import log
from collections import defaultdict, Counter
from transformers import AutoModelForMaskedLM, AutoTokenizer


class BaryScoreMetric:
    def __init__(self, model_name="bert-base-uncased", last_layers=5, use_idfs=True, sinkhorn_ref=0.01):
        """
        BaryScore metric
        :param model_name: model name or path from the HuggingFace library
        :param last_layers: number of last layers of the pretrained model to use
        :param use_idfs: if True, use idf weights; otherwise use uniform weights
        :param sinkhorn_ref: weight of the KL term in the Sinkhorn divergence
        """
        self.model_name = model_name
        self.load_tokenizer_and_model()
        n = self.model.config.num_hidden_layers + 1  # hidden layers plus the embedding layer
        assert n - last_layers > 0
        self.layers_to_consider = range(n - last_layers, n)
        self.use_idfs = use_idfs
        self.sinkhorn_ref = sinkhorn_ref
        self.idfs = []
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def prepare_idfs(self, hyps, refs):
        """
        Compute and cache the idf weights; idfs have to be computed at corpus level.
        :param hyps: list of hypothesis sentences (strings)
        :param refs: list of reference sentences (strings)
        """
        t_hyps = self.tokenizer(hyps)['input_ids']
        t_refs = self.tokenizer(refs)['input_ids']
        idf_dict_ref = self.ref_list_to_idf(t_refs)
        idf_dict_hyp = self.ref_list_to_idf(t_hyps)
        self.model_ids = (idf_dict_ref, idf_dict_hyp)
        return idf_dict_hyp, idf_dict_ref
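
    # Note: evaluate_batch() falls back to the idfs cached in self.model_ids
    # when none are passed explicitly, so prepare_idfs() has to be called once
    # on the full corpus before any evaluation.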

    def ref_list_to_idf(self, input_refs):
        """
        :param input_refs: list of tokenized references (lists of token ids)
        :return: idf dictionary mapping each token id to a smoothed idf weight
        """
        idf_count = Counter()
        num_docs = len(input_refs)
        # Count each token id at most once per document
        idf_count.update(sum([list(set(i)) for i in input_refs], []))
        idf_dict = defaultdict(lambda: log((num_docs + 1) / 1))
        idf_dict.update({idx: log((num_docs + 1) / (c + 1)) for (idx, c) in idf_count.items()})
        return idf_dict
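
    # Worked example of the smoothing above, with hypothetical token ids:
    # for input_refs = [[7, 8], [7]] we get num_docs = 2 and counts {7: 2, 8: 1},
    # so idf(7) = log(3 / 3) = 0.0, idf(8) = log(3 / 2) ≈ 0.405, and any unseen
    # token id falls back to the default log(3 / 1) ≈ 1.099.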

    def load_tokenizer_and_model(self):
        """
        Load and initialize the chosen model and tokenizer.
        """
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModelForMaskedLM.from_pretrained(self.model_name)
        model.config.output_hidden_states = True
        model.eval()
        self.tokenizer = tokenizer
        self.model = model

    def evaluate_batch(self, batch_hyps, batch_refs, idf_hyps=None, idf_ref=None):
        """
        :param batch_hyps: list of hypothesis sentences (strings)
        :param batch_refs: list of reference sentences (strings)
        :param idf_hyps: idfs of the hypotheses, computed at corpus level
        :param idf_ref: idfs of the references, computed at corpus level
        :return: dictionary mapping score names ('baryscore_W', 'baryscore_SD_{reg}')
                 to lists of per-sentence scores
        """
        # Accept single strings as well as lists
        if isinstance(batch_hyps, str):
            batch_hyps = [batch_hyps]
        if isinstance(batch_refs, str):
            batch_refs = [batch_refs]
        nb_sentences = len(batch_refs)
        baryscores = []
        assert len(batch_hyps) == len(batch_refs)
        if (idf_hyps is None) and (idf_ref is None):
            # self.model_ids is (idf_dict_ref, idf_dict_hyp), set by prepare_idfs()
            idf_ref, idf_hyps = self.model_ids
        model = self.model.to(self.device)
        with torch.no_grad():
            ###############################################
            ## Extract Embeddings From Pretrained Models ##
            ###############################################
            batch_refs = self.tokenizer(batch_refs, return_tensors='pt', padding=True,
                                        truncation=True).to(self.device)
            # With output_hidden_states=True, the last output is the tuple of hidden states
            batch_refs_embeddings_ = model(**batch_refs)[-1]
            batch_hyps = self.tokenizer(batch_hyps, return_tensors='pt', padding=True,
                                        truncation=True).to(self.device)
            batch_hyps_embeddings_ = model(**batch_hyps)[-1]
            batch_refs_embeddings = [batch_refs_embeddings_[i] for i in self.layers_to_consider]
            batch_hyps_embeddings = [batch_hyps_embeddings_[i] for i in self.layers_to_consider]
            batch_refs_embeddings = torch.cat([i.unsqueeze(0) for i in batch_refs_embeddings])
            batch_hyps_embeddings = torch.cat([i.unsqueeze(0) for i in batch_hyps_embeddings])
            # L2-normalise the token embeddings so the squared-Euclidean OT cost
            # used downstream is a monotone function of cosine distance
            batch_refs_embeddings.div_(torch.norm(batch_refs_embeddings, dim=-1).unsqueeze(-1))
            batch_hyps_embeddings.div_(torch.norm(batch_hyps_embeddings, dim=-1).unsqueeze(-1))
            ref_tokens_id = batch_refs['input_ids'].cpu().tolist()
            hyp_tokens_id = batch_hyps['input_ids'].cpu().tolist()
            ####################################
            ## Unbatched BaryScore Prediction ##
            ####################################
            for index_sentence in tqdm(range(nb_sentences), 'BaryScore Progress'):
                dict_score = {}
                ref_ids_idf = batch_refs['input_ids'][index_sentence]
                hyp_ids_idf = batch_hyps['input_ids'][index_sentence]
                ref_tokens = [i for i in self.tokenizer.convert_ids_to_tokens(ref_tokens_id[index_sentence],
                                                                              skip_special_tokens=False)
                              if i != self.tokenizer.pad_token]
                hyp_tokens = [i for i in self.tokenizer.convert_ids_to_tokens(hyp_tokens_id[index_sentence],
                                                                              skip_special_tokens=False)
                              if i != self.tokenizer.pad_token]
                ref_ids = list(range(len(ref_tokens)))
                hyp_ids = list(range(len(hyp_tokens)))
                # Stop words are kept; look the idfs up with plain int token ids
                ref_idf_i = [idf_ref[i] for i in ref_ids_idf[ref_ids].tolist()]
                hyp_idf_i = [idf_hyps[i] for i in hyp_ids_idf[hyp_ids].tolist()]
                ref_embedding_i = batch_refs_embeddings[:, index_sentence, ref_ids, :]
                hyp_embedding_i = batch_hyps_embeddings[:, index_sentence, hyp_ids, :]
                # One (n_tokens, dim) float64 array per layer, as expected by baryscore()
                measures_locations_ref = [np.asarray(layer, dtype=np.float64)
                                          for layer in ref_embedding_i.cpu().numpy()]
                measures_locations_hyps = [np.asarray(layer, dtype=np.float64)
                                           for layer in hyp_embedding_i.cpu().numpy()]
                if self.use_idfs:
                    ########################
                    ## Use TF-IDF weights ##
                    ########################
                    baryscore = self.baryscore(measures_locations_ref, measures_locations_hyps,
                                               ref_idf_i, hyp_idf_i)
                else:
                    #####################
                    ## Uniform Weights ##
                    #####################
                    baryscore = self.baryscore(measures_locations_ref, measures_locations_hyps, None, None)
                for key, value in baryscore.items():
                    dict_score['baryscore_{}'.format(key)] = value
                baryscores.append(dict_score)
        # Regroup the per-sentence dicts into one dict of score lists
        baryscores_dic = {k: [score[k] for score in baryscores] for k in dict_score}
        return baryscores_dic

    def baryscore(self, measures_locations_ref, measures_locations_hyps, weights_refs, weights_hyps):
        """
        :param measures_locations_ref: locations of the discrete reference measures (one array per layer)
        :param measures_locations_hyps: locations of the discrete hypothesis measures (one array per layer)
        :param weights_refs: reference weights in the Wasserstein barycenter
        :param weights_hyps: hypothesis weights in the Wasserstein barycenter
        :return: dictionary with the Wasserstein distance ('W') and the Sinkhorn distances ('SD_{reg}')
        """
        if weights_hyps is not None or weights_refs is not None:
            assert weights_refs is not None
            assert weights_hyps is not None
            weights_hyps = np.array([i / sum(weights_hyps) for i in weights_hyps]).astype(np.float64)
            weights_refs = np.array([i / sum(weights_refs) for i in weights_refs]).astype(np.float64)
        self.n_layers = len(measures_locations_ref)
        self.d_bert = measures_locations_ref[0].shape[1]
        ####################################
        ## Compute Wasserstein Barycenter ##
        ####################################
        bary_ref = self.w_barycenter(measures_locations_ref, weights_refs)
        bary_hyp = self.w_barycenter(measures_locations_hyps, weights_hyps)
        #################################################
        ## Compute Wasserstein and Sinkhorn Divergence ##
        #################################################
        C = ot.dist(bary_ref, bary_hyp)  # squared-Euclidean cost by default
        weights_first_barycenter = np.full(C.shape[0], 1 / C.shape[0])
        weights_second_barycenter = np.full(C.shape[1], 1 / C.shape[1])
        wasserstein_distance = ot.emd2(weights_first_barycenter, weights_second_barycenter, C,
                                       log=True)[0]
        dic_results = {
            "W": wasserstein_distance,
        }
        for reg in [10, 5, 1, 0.5, 0.1, 0.01, 0.001]:
            wasserstein_sinkhorn = ot.bregman.sinkhorn2(weights_first_barycenter, weights_second_barycenter,
                                                        C, reg=reg, numItermax=10000).tolist()
            if isinstance(wasserstein_sinkhorn, list):
                wasserstein_sinkhorn = wasserstein_sinkhorn[0]  # for POT==0.7.0
            dic_results['SD_{}'.format(reg)] = wasserstein_sinkhorn
        return dic_results
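
    # For reference: with reg = eps, ot.bregman.sinkhorn2 above solves the entropic
    # problem  min_P <P, C> + eps * sum_ij P_ij log P_ij  over couplings P of the two
    # uniform barycenter weights, so 'SD_{eps}' approaches the exact Wasserstein
    # cost 'W' as eps -> 0 and a heavily smoothed transport cost for large eps.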

    def w_barycenter(self, measures_locations, weights):
        """
        :param measures_locations: locations of the discrete input measures (one array per layer)
        :param weights: weights of the input measures
        :return: barycentric distribution
        """
        X_init = np.zeros((measures_locations[0].shape[0], self.d_bert)).astype(np.float64)
        if weights is None:
            measures_weights = [np.array(
                [1 / measures_locations[0].shape[0]] * measures_locations[0].shape[0])] * self.n_layers
        else:
            measures_weights = [weights / sum(weights)] * self.n_layers
        b = np.array([1 / measures_locations[0].shape[0]] * measures_locations[0].shape[0]).astype(np.float64)
        measure_bary = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init,
                                                     b=b, numItermax=1000, verbose=False)
        return measure_bary
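
    # A minimal sketch of the call above on hypothetical 2-D toy measures:
    #   locs = [np.array([[0., 0.], [1., 0.]]), np.array([[0., 1.], [1., 1.]])]
    #   wts = [np.array([.5, .5]), np.array([.5, .5])]
    #   ot.lp.free_support_barycenter(locs, wts, np.array([[0., 0.], [1., 1.]]),
    #                                 b=np.array([.5, .5]))
    # should converge to the midpoint cloud [[0., .5], [1., .5]]: each support
    # point of the barycenter is the weighted mean of the points it is matched to.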

    def supports_multi_ref(self):
        """
        :return: False; BaryScore does not support multiple references
        """
        return False


if __name__ == '__main__':
    # Example of how to use BaryScore
    metric_call = BaryScoreMetric(use_idfs=False)
    ref = ['I like my cakes very much',
           'I hate these cakes!']
    hypothesis = ['I like my cakes very much',
                  'I like my cakes very much']
    # Both methods take the hypotheses first, then the references
    metric_call.prepare_idfs(hypothesis, ref)
    final_preds = metric_call.evaluate_batch(hypothesis, ref)
    print(final_preds)
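    # Expected output format (values depend on the model weights): a dict such as
    # {'baryscore_W': [w_1, w_2], 'baryscore_SD_10': [...], ..., 'baryscore_SD_0.001': [...]},
    # with one list entry per (hypothesis, reference) pair.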