from typing import List, Optional, Dict
import numpy as np
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F
torch.set_grad_enabled(False)
def apply_top_p_with_epsilon(logits: torch.Tensor, top_p: float, epsilon: float = 1e-10) -> torch.Tensor:
"""
Applies a top-p (nucleus) filtering to logits but, instead of setting
the logits of non-selected tokens to -inf (which would result in zero probability),
sets them to log(epsilon), so that the support remains the same.
Parameters:
logits: Tensor of shape (batch, seq_len, vocab_size)
top_p: The nucleus threshold (e.g. 0.7, 0.8, etc.)
epsilon: The small value to assign to tokens not selected.
Returns:
new_logits: Tensor with the same shape as logits.
"""
# Compute probabilities from logits
probs = F.softmax(logits, dim=-1)
# Sort probabilities (descending) along the vocabulary dimension.
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
# Compute the cumulative sum along the sorted probabilities.
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# Create a mask: True for tokens to keep.
# We keep tokens until cumulative_probs <= top_p.
keep_mask = cumulative_probs <= top_p
# Ensure that at least one token is kept per example: if none are kept, keep the top one.
# Here we check along the vocab dimension.
no_token_kept = keep_mask.sum(dim=-1, keepdim=True) == 0
if no_token_kept.any():
# For positions where no token was kept, set the first token (highest probability) to True.
# Note: torch.scatter_ returns a modified tensor.
# We create a tensor of zeros (False) and then scatter True into the first column.
fix_mask = torch.zeros_like(keep_mask, dtype=torch.bool)
fix_mask.scatter_(-1, torch.zeros_like(keep_mask[..., :1], dtype=torch.long), True)
keep_mask = torch.where(no_token_kept, fix_mask, keep_mask)
# Now, create new logits: copy the original logits.
new_logits = logits.clone()
# For tokens that are not kept (i.e. where keep_mask is False), set their logit to log(epsilon)
new_logits[~keep_mask] = torch.log(torch.tensor(epsilon, device=logits.device, dtype=logits.dtype))
return new_logits
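

# Minimal sanity sketch for the warper above (illustrative only; the helper name
# and tensor shapes are assumptions, not used elsewhere in this file): warp
# random logits and check that every token keeps a non-zero probability, i.e.
# the support is unchanged, unlike a hard -inf nucleus cut-off.
def _demo_top_p_warping() -> None:
    dummy_logits = torch.randn(2, 4, 10)  # hypothetical (batch, seq_len, vocab_size)
    warped = apply_top_p_with_epsilon(dummy_logits, top_p=0.8)
    warped_probs = F.softmax(warped, dim=-1)
    assert warped_probs.shape == dummy_logits.shape
    assert torch.all(warped_probs > 0)  # full support preserved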
class Mosaic(object):
def __init__(
self,
model_name_or_paths: List[str],
use_bfloat16: bool = True,
max_token_observed: int = 512,
unigram: Optional[str] = None,
custom_config: Optional[List[bool]] = None,
stupid_mode: bool = False,
one_model_mode: bool = False
) -> None:
"""
If `loaded_models` is provided, re-use any entries matching
model_name_or_paths; otherwise load and optionally register
into that dict.
"""
self.models = []
for model_name_or_path in model_name_or_paths:
# load from pre-trained hub or path
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.bfloat16 if use_bfloat16 else torch.float32,
)
model.eval()
self.models.append(model)
print(f"Loaded model: {model_name_or_path}")
self.one_model_mode = one_model_mode
if stupid_mode:
self.max_iters = 0
else:
self.max_iters = 1000
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_paths[-1])
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.max_token_observed = max_token_observed
self.nb_models = len(self.models)
self.unigram_path = unigram
if custom_config is None:
custom_config = [False] * self.nb_models
self.custom_config = custom_config
def _tokenize(self, batch: list[str]) -> transformers.BatchEncoding:
encodings = self.tokenizer(
batch,
return_tensors="pt",
padding="longest",
truncation=True,
max_length=self.max_token_observed,
return_token_type_ids=False)
return encodings
def trim_logits(self, logits, max_length=32000):
# Check the shape of the logits tensor
if logits.shape[2] > max_length:
# Slice the tensor to keep only the first max_length elements along the last dimension
logits = logits[:, :, :max_length]
return logits
@torch.inference_mode()
def _get_logits(self, encodings: transformers.BatchEncoding) -> List[torch.Tensor]:
# If one_model_mode is active, we simulate multiple models by applying top-p with different thresholds.
if self.one_model_mode:
# Compute base logits from the single model.
model = self.models[0]
device = next(model.parameters()).device
model_encodings = encodings.to(device)
base_logits = model(**model_encodings).logits
# Optionally trim logits:
# base_logits = self.trim_logits(base_logits)
# Define the top-p thresholds (e.g., four different values)
top_p_values = [0.7, 0.8, 0.9, 0.95]
# Epsilon value for non-selected tokens (you can adjust this if needed)
epsilon = 1e-10
logits_list = []
for top_p in top_p_values:
warped_logits = apply_top_p_with_epsilon(base_logits, top_p, epsilon)
logits_list.append(warped_logits)
else:
# Normal mode: use each model in self.models.
logits_list = []
for i, model in enumerate(self.models):
device = next(model.parameters()).device
model_encodings = encodings.to(device)
logits = model(**model_encodings).logits
# Optionally trim logits:
# logits = self.trim_logits(logits)
logits_list.append(logits)
if device.type == "cuda":
torch.cuda.synchronize(device)
        if self.unigram_path:
            # Append a unigram "model": a position-independent distribution loaded from disk.
            batch_size, seq_len, voc_size = logits_list[0].shape
            unigram_proba = torch.load(self.unigram_path)
            unigram_proba += 1e-10  # avoid log(0)
            unigram_logits = torch.log(unigram_proba).to(
                device=logits_list[0].device, dtype=logits_list[0].dtype
            )
            # Optionally center the model logits if needed:
            # logits_list[0] = logits_list[0] - logits_list[0].mean(dim=-1, keepdim=True)
            expanded_unigram_logits = unigram_logits.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, voc_size)
            logits_list.append(expanded_unigram_logits)
return logits_list
def get_softmax_probabilities(self, input_text):
encodings = self._tokenize(input_text)
logits_list = self._get_logits(encodings)
probabilities_list = softmax_probabilities_all_models(logits_list)
return encodings, logits_list, probabilities_list
def compute_arimoto_torch(self, input_text, max_iters=1000):
encodings, logits_list, tensors_list = self.get_softmax_probabilities(input_text)
nb_models = len(tensors_list)
seq_len = len(encodings.input_ids[0])
voc_size = tensors_list[0].shape[-1]
device = tensors_list[0].device
# Move all tensors in tensors_list to the device of the first tensor
tensors_list = [tensor.to(device) for tensor in tensors_list]
# Stack all model predictions along a new dimension to form a (seq_len, nb_models, voc_size) tensor
probabilities_tensor = torch.stack([t[0] for t in tensors_list], dim=1).to(tensors_list[0].device)
# Run the Blahut-Arimoto algorithm on the entire batch
capacity, p = blahut_arimoto_torch(probabilities_tensor, max_iters=max_iters)
        # Weighted mixture of the model distributions, initially zeros.
        weighted_sum_tensor = torch.zeros_like(tensors_list[0])
        # p has shape (seq_len, nb_models): weight each model's distribution
        # position by position (this code path assumes a batch size of 1).
        for i in range(nb_models):
            weighted_sum_tensor += p[:, i:i + 1] * tensors_list[i]
return encodings, weighted_sum_tensor, tensors_list, p, logits_list
def compute_scores(self, input_text):
encodings, weighted_sum_tensor, probabilities_list, arimoto_weights, logits_list = self.compute_arimoto_torch(input_text, max_iters=self.max_iters)
log_ppl, ppl, nll = perplexity(encodings, weighted_sum_tensor)
ppl_list = perplexity_all_models(encodings, logits_list)
x_ppl_list = cross_entropy(weighted_sum_tensor, probabilities_list)
return log_ppl, x_ppl_list, arimoto_weights, nll, ppl_list
def compute_end_score(self, input_text):
encodings, weighted_sum_tensor, probabilities_list, arimoto_weights, logits_list = self.compute_arimoto_torch(input_text)
log_ppl, ppl, nll = perplexity(encodings, weighted_sum_tensor)
ppl_list = perplexity_all_models(encodings, logits_list)
x_ppl_list = cross_entropy(weighted_sum_tensor, probabilities_list)
log_ppl_value = log_ppl.item()
x_ppl_values = [x.item() for x in x_ppl_list]
        final_score = log_ppl_value - x_ppl_values[0]  # the "reference model" must be the first entry of model_name_or_paths
return final_score
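

# Illustrative usage sketch (hypothetical helper, not called anywhere in this
# file; the model names are assumptions). Any causal LMs sharing a tokenizer and
# vocabulary work, with the reference model listed first; instantiating Mosaic
# downloads the models.
def _demo_mosaic_end_score() -> float:
    detector = Mosaic(model_name_or_paths=["gpt2", "gpt2-medium"])
    # Batch of one text; the score is the log-perplexity under the Arimoto
    # mixture minus the cross-entropy between the first (reference) model's
    # distribution and that mixture.
    return detector.compute_end_score(["Some text whose origin we want to score."])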
def perplexity(encodings, weighted_sum_tensor):
shifted_probabilities = weighted_sum_tensor[..., :-1, :].contiguous()
shifted_labels = encodings.input_ids[..., 1:].contiguous()
shifted_attention_mask = encodings.attention_mask[..., 1:].contiguous()
device = shifted_probabilities.device
# Ensure all tensors are moved to the same device
shifted_probabilities = shifted_probabilities.to(device)
shifted_labels = shifted_labels.to(device)
shifted_attention_mask = shifted_attention_mask.to(device)
actual_next_token_probabilities = torch.gather(shifted_probabilities, 2, shifted_labels.unsqueeze(-1)).squeeze(-1)
nll = -torch.log(actual_next_token_probabilities + 1e-12)
nll_masked = nll * shifted_attention_mask
# Calculate the average NLL per sequence, taking into account only the valid (non-padded) tokens
average_nll = torch.sum(nll_masked, dim=1) / torch.sum(shifted_attention_mask, dim=1)
# Calculate perplexity per sequence
perplexity = torch.exp(average_nll)
return average_nll, perplexity, nll_masked
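

# Quick worked check of the mixture perplexity above (illustrative only, toy
# tensors): with a uniform mixture over vocab_size tokens, the per-token NLL is
# log(vocab_size) and the perplexity equals vocab_size.
def _demo_perplexity_uniform() -> None:
    vocab_size, seq_len = 8, 5
    uniform_mixture = torch.full((1, seq_len, vocab_size), 1.0 / vocab_size)
    enc = transformers.BatchEncoding({
        "input_ids": torch.randint(0, vocab_size, (1, seq_len)),
        "attention_mask": torch.ones(1, seq_len, dtype=torch.long),
    })
    log_ppl, ppl, _ = perplexity(enc, uniform_mixture)
    print(log_ppl)  # ~log(8) ≈ 2.079
    print(ppl)      # ~8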
def cross_entropy(weighted_sum_tensor, probabilities_list):
    """
    For each model, computes -mean_t sum_v m2(v|t) * log(mixture(v|t)), i.e. the
    cross-entropy of the Arimoto mixture under that model's distribution,
    averaged over sequence positions.
    """
    device = weighted_sum_tensor.device
    x_ppl_list = []
    # Log of the mixture; it does not depend on the individual model distributions.
    log_M1 = torch.log(weighted_sum_tensor).to(device)
    for m2_probabilities in probabilities_list:
        m2_probabilities = m2_probabilities.to(device)
        # Dot product over the vocabulary at each position:
        # sum_v m2(v) * log M1(v), shape (batch_size, sequence_length).
        dot_products = (log_M1 * m2_probabilities).sum(dim=-1)
        # Average over the sequence dimension and negate to obtain a cross-entropy.
        x_ppl = -torch.mean(dot_products, dim=1)
        x_ppl_list.append(x_ppl)
    return x_ppl_list
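

# Illustrative check for cross_entropy (toy tensors, not real model outputs):
# when a model's distribution equals the mixture itself, the returned value is
# the average entropy of the mixture, i.e. log(vocab_size) for a uniform one.
def _demo_cross_entropy_uniform() -> None:
    vocab_size, seq_len = 8, 5
    mixture = torch.full((1, seq_len, vocab_size), 1.0 / vocab_size)
    x_ppl_list = cross_entropy(mixture, [mixture])
    print(x_ppl_list[0])  # ~log(8) ≈ 2.079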
def softmax_probabilities_all_models(logits_list: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Calculates the softmax probabilities for the entire sequence of tokens for each model.
Parameters:
- logits_list: List[torch.Tensor]
A list containing the logits tensor for each model.
Returns:
- List[torch.Tensor]: A list of tensors, where each tensor is the softmax probabilities
for one model across the entire sequence of tokens.
"""
softmax_fn = torch.nn.Softmax(dim=-1)
probabilities_list = []
for logits in logits_list:
# Calculate softmax probabilities across the vocabulary for each token position
softmax_probabilities = softmax_fn(logits)
probabilities_list.append(softmax_probabilities)
return probabilities_list
def perplexity_logits(encoding, logits):
# Ensure encoding tensors are moved to the same device as logits
device = logits.device
logits = torch.clamp(logits, min=-20, max=50)
encoding_input_ids = encoding.input_ids.to(device)
encoding_attention_mask = encoding.attention_mask.to(device)
ce_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
shifted_logits = logits[..., :-1, :].contiguous()
shifted_labels = encoding_input_ids[..., 1:].contiguous()
shifted_attention_mask = encoding_attention_mask[..., 1:].contiguous()
# Calculate Cross-Entropy loss
cross_entropy_loss = ce_loss_fn(shifted_logits.transpose(1, 2), shifted_labels)
# Apply attention mask
masked_ce_loss = cross_entropy_loss * shifted_attention_mask
    # Average NLL over the valid tokens (i.e. log-perplexity per sequence)
    ppl = masked_ce_loss.sum(1) / shifted_attention_mask.sum(1)
# Move result to CPU and convert to numpy for further processing if needed
ppl = ppl.to("cpu").float().numpy()
return ppl
def perplexity_all_models(encoding, logits_list):
ppl_list = []
for logits in logits_list:
ppl = perplexity_logits(encoding, logits)
ppl_list.append(ppl)
return ppl_list
def blahut_arimoto_torch(W, epsilon=1e-6, max_iters=1000):
"""
Batch-process Blahut-Arimoto using PyTorch for multiple sequences.
"""
seq_len, nb_models, voc_size = W.shape
p = torch.full((seq_len, nb_models), 1.0 / nb_models, device=W.device, dtype=W.dtype)
prod_exp = torch.ones((seq_len, nb_models), device=W.device, dtype=W.dtype)
for _ in range(max_iters):
# Calculate the marginal probabilities
sum_p_w = torch.bmm(p.unsqueeze(1), W).squeeze(1) # Resultant shape: (seq_len, voc_size)
# Calculate normalized probabilities
W_normalized = W / sum_p_w.unsqueeze(1) # Broadcasting to shape (seq_len, nb_models, voc_size)
# Avoid numerical issues with logarithms
W_normalized[W_normalized == 0] = torch.finfo(W.dtype).eps
log_term = torch.log(W_normalized)
log_term[torch.isnan(log_term) | torch.isinf(log_term)] = 0
        # exp of sum_v W_m(v) * log(W_m(v) / marginal(v)) per model, then the standard update of p.
        prod_exp = torch.exp(torch.sum(W * log_term, dim=2))  # sum over voc_size
        p_new = (p * prod_exp) / torch.sum(p * prod_exp, dim=1, keepdim=True)
# Check convergence
if torch.max(torch.abs(p - p_new)) < epsilon:
break
p = p_new
    # Channel capacity per position, converted from nats to bits.
    capacity = torch.log(torch.sum(p * prod_exp, dim=1)) / torch.log(torch.tensor(2.0, device=W.device))
    return capacity, p
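

# Small sanity sketch for blahut_arimoto_torch (illustrative only, not part of
# the detector): identical conditional distributions give zero capacity, while
# (nearly) disjoint ones approach log2(nb_models) bits with uniform weights.
def _demo_blahut_arimoto() -> None:
    seq_len, nb_models, voc_size = 3, 2, 5
    # Identical rows -> zero capacity at every position.
    same = torch.full((seq_len, nb_models, voc_size), 1.0 / voc_size)
    capacity_same, _ = blahut_arimoto_torch(same)
    # (Nearly) disjoint supports -> capacity close to log2(2) = 1 bit.
    distinct = torch.full((seq_len, nb_models, voc_size), 1e-6)
    distinct[:, 0, 0] = 1.0
    distinct[:, 1, 1] = 1.0
    distinct = distinct / distinct.sum(dim=-1, keepdim=True)
    capacity_distinct, weights = blahut_arimoto_torch(distinct)
    print(capacity_same)      # ~[0, 0, 0]
    print(capacity_distinct)  # ~[1, 1, 1] bits
    print(weights)            # ~0.5 / 0.5 per position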