from typing import List, Optional, Dict
import numpy as np
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F
torch.set_grad_enabled(False)
def apply_top_p_with_epsilon(logits: torch.Tensor, top_p: float, epsilon: float = 1e-10) -> torch.Tensor:
"""
    Applies top-p (nucleus) filtering to logits but, instead of setting the logits
    of non-selected tokens to -inf (which would give them zero probability),
    sets them to log(epsilon), so that the support of the distribution is unchanged.
Parameters:
logits: Tensor of shape (batch, seq_len, vocab_size)
top_p: The nucleus threshold (e.g. 0.7, 0.8, etc.)
epsilon: The small value to assign to tokens not selected.
Returns:
new_logits: Tensor with the same shape as logits.
"""
# Compute probabilities from logits
probs = F.softmax(logits, dim=-1)
# Sort probabilities (descending) along the vocabulary dimension.
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
# Compute the cumulative sum along the sorted probabilities.
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Build a mask over the *sorted* tokens: True for tokens to keep.
    # Tokens are kept while the cumulative probability stays within top_p.
    sorted_keep_mask = cumulative_probs <= top_p
    # Ensure at least one token is kept per position: the highest-probability one.
    sorted_keep_mask[..., 0] = True
    # Scatter the mask back from sorted order to the original vocabulary order
    # so that it lines up with the un-sorted logits.
    keep_mask = torch.zeros_like(sorted_keep_mask)
    keep_mask.scatter_(-1, sorted_indices, sorted_keep_mask)
    # Copy the original logits and set tokens outside the nucleus to log(epsilon)
    # (rather than -inf), keeping the support of the distribution unchanged.
    new_logits = logits.clone()
    new_logits[~keep_mask] = torch.log(torch.tensor(epsilon, device=logits.device, dtype=logits.dtype))
return new_logits
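# Illustrative sketch (not called anywhere in this module): demonstrates that
# apply_top_p_with_epsilon keeps the full vocabulary as support while pushing
# out-of-nucleus tokens down to roughly epsilon probability. The shapes and
# values below are arbitrary assumptions, used only for demonstration.
def _example_top_p_with_epsilon() -> None:
    dummy_logits = torch.randn(2, 4, 10)  # (batch=2, seq_len=4, vocab_size=10)
    warped = apply_top_p_with_epsilon(dummy_logits, top_p=0.8)
    probs = F.softmax(warped, dim=-1)
    # Every token keeps a strictly positive probability (same support as before),
    # but almost all of the mass is now concentrated inside the nucleus.
    print(probs.min().item(), probs.sum(dim=-1)[0, 0].item())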
class Mosaic(object):
def __init__(
self,
model_name_or_paths: List[str],
use_bfloat16: bool = True,
max_token_observed: int = 512,
unigram: Optional[str] = None,
custom_config: Optional[List[bool]] = None,
stupid_mode: bool = False,
one_model_mode: bool = False,
# new optional argument: preloaded models dict
loaded_models: Optional[Dict[str, AutoModelForCausalLM]] = None,
) -> None:
"""
        If `loaded_models` is provided, entries matching `model_name_or_paths` are
        reused; any model that is missing is loaded from disk and registered back
        into that dict so that later instances can share it.
"""
self.models = []
# ensure we have a dict to cache into if passed
cache = loaded_models if loaded_models is not None else {}
for model_name_or_path in model_name_or_paths:
# reuse if already loaded
if loaded_models is not None and model_name_or_path in cache:
model = cache[model_name_or_path]
else:
print("Reloading a model was necessary, you probably messed up.")
# load from pre-trained hub or path
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.bfloat16 if use_bfloat16 else torch.float32,
)
model.eval()
# cache for reuse
if loaded_models is not None:
cache[model_name_or_path] = model
self.models.append(model)
print(f"Loaded model: {model_name_or_path}")
# store optional references
self.loaded_models = cache
self.one_model_mode = one_model_mode
if stupid_mode:
self.max_iters = 0
else:
self.max_iters = 1000
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_paths[-1])
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.max_token_observed = max_token_observed
self.nb_models = len(self.models)
self.unigram_path = unigram
if custom_config is None:
custom_config = [False] * self.nb_models
self.custom_config = custom_config
def _tokenize(self, batch: list[str]) -> transformers.BatchEncoding:
encodings = self.tokenizer(
batch,
return_tensors="pt",
padding="longest",
truncation=True,
max_length=self.max_token_observed,
return_token_type_ids=False)
return encodings
def trim_logits(self, logits, max_length=32000):
# Check the shape of the logits tensor
if logits.shape[2] > max_length:
# Slice the tensor to keep only the first max_length elements along the last dimension
logits = logits[:, :, :max_length]
return logits
@torch.inference_mode()
def _get_logits(self, encodings: transformers.BatchEncoding) -> List[torch.Tensor]:
# If one_model_mode is active, we simulate multiple models by applying top-p with different thresholds.
if self.one_model_mode:
# Compute base logits from the single model.
model = self.models[0]
device = next(model.parameters()).device
model_encodings = encodings.to(device)
base_logits = model(**model_encodings).logits
# Optionally trim logits:
# base_logits = self.trim_logits(base_logits)
# Define the top-p thresholds (e.g., four different values)
top_p_values = [0.7, 0.8, 0.9, 0.95]
# Epsilon value for non-selected tokens (you can adjust this if needed)
epsilon = 1e-10
logits_list = []
for top_p in top_p_values:
warped_logits = apply_top_p_with_epsilon(base_logits, top_p, epsilon)
logits_list.append(warped_logits)
else:
# Normal mode: use each model in self.models.
logits_list = []
for i, model in enumerate(self.models):
device = next(model.parameters()).device
model_encodings = encodings.to(device)
logits = model(**model_encodings).logits
# Optionally trim logits:
# logits = self.trim_logits(logits)
logits_list.append(logits)
if device.type == "cuda":
torch.cuda.synchronize(device)
if self.unigram_path:
batch_size, seq_len, voc_size = logits_list[0].shape
unigram_proba = torch.load(self.unigram_path)
unigram_proba += 1e-10
unigram_logits = torch.log(unigram_proba)
            # Optionally center the logits if needed (currently unused):
            # logits = logits_list[0] - logits_list[0].mean(dim=-1, keepdim=True)
expanded_unigram_logits = unigram_logits.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, voc_size)
logits_list.append(expanded_unigram_logits)
return logits_list
def get_softmax_probabilities(self, input_text):
encodings = self._tokenize(input_text)
logits_list = self._get_logits(encodings)
probabilities_list = softmax_probabilities_all_models(logits_list)
return encodings, logits_list, probabilities_list
def compute_arimoto_torch(self, input_text, max_iters=1000):
encodings, logits_list, tensors_list = self.get_softmax_probabilities(input_text)
nb_models = len(tensors_list)
seq_len = len(encodings.input_ids[0])
voc_size = tensors_list[0].shape[-1]
device = tensors_list[0].device
# Move all tensors in tensors_list to the device of the first tensor
tensors_list = [tensor.to(device) for tensor in tensors_list]
# Stack all model predictions along a new dimension to form a (seq_len, nb_models, voc_size) tensor
probabilities_tensor = torch.stack([t[0] for t in tensors_list], dim=1).to(tensors_list[0].device)
# Run the Blahut-Arimoto algorithm on the entire batch
capacity, p = blahut_arimoto_torch(probabilities_tensor, max_iters=max_iters)
# Prepare the weighted sum tensor, initially zeros
weighted_sum_tensor = torch.zeros_like(tensors_list[0])
        # p has shape (seq_len, nb_models): weight each model's probabilities by its
        # per-position Arimoto weight and accumulate them into the mixture.
for i in range(nb_models):
weighted_sum_tensor += p[:, i:i+1] * tensors_list[i]
return encodings, weighted_sum_tensor, tensors_list, p, logits_list
def compute_scores(self, input_text):
encodings, weighted_sum_tensor, probabilities_list, arimoto_weights, logits_list = self.compute_arimoto_torch(input_text, max_iters=self.max_iters)
log_ppl, ppl, nll = perplexity(encodings, weighted_sum_tensor)
ppl_list = perplexity_all_models(encodings, logits_list)
x_ppl_list = cross_entropy(weighted_sum_tensor, probabilities_list)
return log_ppl, x_ppl_list, arimoto_weights, nll, ppl_list
def compute_end_score(self, input_text):
encodings, weighted_sum_tensor, probabilities_list, arimoto_weights, logits_list = self.compute_arimoto_torch(input_text)
log_ppl, ppl, nll = perplexity(encodings, weighted_sum_tensor)
ppl_list = perplexity_all_models(encodings, logits_list)
x_ppl_list = cross_entropy(weighted_sum_tensor, probabilities_list)
log_ppl_value = log_ppl.item()
x_ppl_values = [x.item() for x in x_ppl_list]
        final_score = log_ppl_value - x_ppl_values[0]  # Ensure the "reference model" is passed first in model_name_or_paths
return final_score
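# Illustrative sketch (never executed on import): how a caller might construct a
# Mosaic detector and reuse already-loaded models through `loaded_models`. The
# model identifiers below are placeholders, not the defaults shipped with this repo.
def _example_mosaic_usage() -> float:
    model_paths = ["model-a", "model-b"]  # placeholder model names/paths
    cache: Dict[str, AutoModelForCausalLM] = {}
    detector = Mosaic(model_paths, loaded_models=cache)  # loads both models and caches them
    # A second instance with the same paths reuses the cached models instead of reloading.
    detector_again = Mosaic(model_paths, loaded_models=cache)
    # The score is the log-perplexity under the Arimoto-weighted mixture minus the
    # cross-entropy against the first (reference) model.
    return detector_again.compute_end_score(["Some text to score."])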
def perplexity(encodings, weighted_sum_tensor):
shifted_probabilities = weighted_sum_tensor[..., :-1, :].contiguous()
shifted_labels = encodings.input_ids[..., 1:].contiguous()
shifted_attention_mask = encodings.attention_mask[..., 1:].contiguous()
device = shifted_probabilities.device
# Ensure all tensors are moved to the same device
shifted_probabilities = shifted_probabilities.to(device)
shifted_labels = shifted_labels.to(device)
shifted_attention_mask = shifted_attention_mask.to(device)
actual_next_token_probabilities = torch.gather(shifted_probabilities, 2, shifted_labels.unsqueeze(-1)).squeeze(-1)
nll = -torch.log(actual_next_token_probabilities + 1e-12)
nll_masked = nll * shifted_attention_mask
# Calculate the average NLL per sequence, taking into account only the valid (non-padded) tokens
average_nll = torch.sum(nll_masked, dim=1) / torch.sum(shifted_attention_mask, dim=1)
# Calculate perplexity per sequence
perplexity = torch.exp(average_nll)
return average_nll, perplexity, nll_masked
def cross_entropy(weighted_sum_tensor, probabilities_list):
device = weighted_sum_tensor.device
x_ppl_list = []
# Compute log of weighted_sum_tensor outside the loop since it doesn't depend on m2_probabilities
log_M1 = torch.log(weighted_sum_tensor).to(device)
for m2_probabilities in probabilities_list:
m2_probabilities = m2_probabilities.to(device)
# Ensure m2_probabilities is correctly shaped for batch matrix multiplication
# log_M1 shape is already (batch_size, sequence_length, vocabulary_size)
# We need m2_probabilities in shape (batch_size, vocabulary_size, sequence_length) for bmm
m2_probabilities_transposed = m2_probabilities.transpose(1, 2)
# Perform batch matrix multiplication
# Resulting shape: (batch_size, sequence_length, sequence_length)
# We sum over the vocabulary dimension, effectively computing the dot product for each sequence position
dot_products = torch.bmm(log_M1, m2_probabilities_transposed)
        # We only need the diagonal (the dot products of matching sequence positions),
        # so we extract it per batch item with einsum (equivalent to
        # torch.diagonal(dot_products, dim1=1, dim2=2)).
        dot_products_diagonal = torch.einsum('bii->bi', dot_products)
# Compute the mean of the dot_products_diagonal across the sequence dimension
# This gives us the average dot product per sequence, which is then negated
x_ppl = -torch.mean(dot_products_diagonal, dim=1)
x_ppl_list.append(x_ppl)
x_ppl_tensor = torch.stack(x_ppl_list)
return x_ppl_list #, x_ppl_tensor
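# Note on the bmm/diagonal trick above: for each (batch, position) the quantity of
# interest is the dot product of log_M1[b, t, :] with m2_probabilities[b, t, :].
# A sketch of an equivalent, more memory-frugal formulation is kept below purely as
# illustration of what cross_entropy computes; it is not used by this module.
def _cross_entropy_rowwise(weighted_sum_tensor: torch.Tensor, m2_probabilities: torch.Tensor) -> torch.Tensor:
    log_M1 = torch.log(weighted_sum_tensor)
    # Element-wise product summed over the vocabulary dimension gives the same
    # per-position dot products as extracting the diagonal of the bmm result.
    dot_products = (log_M1 * m2_probabilities.to(log_M1.device)).sum(dim=-1)  # (batch, seq_len)
    return -dot_products.mean(dim=1)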
def softmax_probabilities_all_models(logits_list: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Calculates the softmax probabilities for the entire sequence of tokens for each model.
Parameters:
- logits_list: List[torch.Tensor]
A list containing the logits tensor for each model.
Returns:
- List[torch.Tensor]: A list of tensors, where each tensor is the softmax probabilities
for one model across the entire sequence of tokens.
"""
softmax_fn = torch.nn.Softmax(dim=-1)
probabilities_list = []
for logits in logits_list:
# Calculate softmax probabilities across the vocabulary for each token position
softmax_probabilities = softmax_fn(logits)
probabilities_list.append(softmax_probabilities)
return probabilities_list
def perplexity_logits(encoding, logits):
# Ensure encoding tensors are moved to the same device as logits
device = logits.device
logits = torch.clamp(logits, min=-20, max=50)
encoding_input_ids = encoding.input_ids.to(device)
encoding_attention_mask = encoding.attention_mask.to(device)
ce_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
shifted_logits = logits[..., :-1, :].contiguous()
shifted_labels = encoding_input_ids[..., 1:].contiguous()
shifted_attention_mask = encoding_attention_mask[..., 1:].contiguous()
# Calculate Cross-Entropy loss
cross_entropy_loss = ce_loss_fn(shifted_logits.transpose(1, 2), shifted_labels)
# Apply attention mask
masked_ce_loss = cross_entropy_loss * shifted_attention_mask
# Calculate perplexity
ppl = masked_ce_loss.sum(1) / shifted_attention_mask.sum(1)
# Move result to CPU and convert to numpy for further processing if needed
ppl = ppl.to("cpu").float().numpy()
return ppl
def perplexity_all_models(encoding, logits_list):
ppl_list = []
for logits in logits_list:
ppl = perplexity_logits(encoding, logits)
ppl_list.append(ppl)
return ppl_list
def blahut_arimoto_torch(W, epsilon=1e-6, max_iters=1000):
"""
Batch-process Blahut-Arimoto using PyTorch for multiple sequences.
"""
seq_len, nb_models, voc_size = W.shape
p = torch.full((seq_len, nb_models), 1.0 / nb_models, device=W.device, dtype=W.dtype)
prod_exp = torch.ones((seq_len, nb_models), device=W.device, dtype=W.dtype)
for _ in range(max_iters):
# Calculate the marginal probabilities
sum_p_w = torch.bmm(p.unsqueeze(1), W).squeeze(1) # Resultant shape: (seq_len, voc_size)
# Calculate normalized probabilities
W_normalized = W / sum_p_w.unsqueeze(1) # Broadcasting to shape (seq_len, nb_models, voc_size)
# Avoid numerical issues with logarithms
W_normalized[W_normalized == 0] = torch.finfo(W.dtype).eps
log_term = torch.log(W_normalized)
log_term[torch.isnan(log_term) | torch.isinf(log_term)] = 0
# Compute product exponentials and update probabilities
        prod_exp = torch.exp(torch.sum(W * log_term, dim=2))  # Sum across voc_size
p_new = (p * prod_exp) / torch.sum(p * prod_exp, dim=1, keepdim=True)
# Check convergence
if torch.max(torch.abs(p - p_new)) < epsilon:
break
p = p_new
# Compute channel capacity
capacity = torch.log(torch.sum(p * prod_exp, dim=1)) / torch.log(torch.tensor(2.0, device=W.device))
return capacity, p
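# Illustrative sketch: running the batched Blahut-Arimoto routine on a random
# conditional distribution. Each W[t] is a (nb_models, voc_size) row-stochastic
# matrix; the routine returns the capacity (in bits) and the capacity-achieving
# weights over models for every position. Sizes below are arbitrary assumptions.
def _example_blahut_arimoto() -> None:
    seq_len, nb_models, voc_size = 5, 4, 32
    W = torch.rand(seq_len, nb_models, voc_size)
    W = W / W.sum(dim=-1, keepdim=True)  # normalize rows into valid distributions
    capacity, p = blahut_arimoto_torch(W)
    print(capacity.shape, p.shape)  # torch.Size([5]), torch.Size([5, 4])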