from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Gemma2Model, Gemma2PreTrainedModel
from transformers.utils import ModelOutput


class GatingNetwork(nn.Module):
    """
    Gating network: a simple MLP with softmax output and temperature scaling.

    The network learns to combine multiple reward objectives based on the input context.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        temperature: float = 10,
        logit_scale: float = 1.0,
        hidden_dim: int = 1024,
        n_hidden: int = 3,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.temperature = temperature
        self.logit_scale = nn.Parameter(torch.ones(1) * logit_scale)
        layers = []
        dropout_rate = dropout
        for i in range(n_hidden):
            # Bias is omitted because each linear layer is followed by BatchNorm.
            layers.append(nn.Linear(in_features, hidden_dim, bias=False))
            # nn.init.kaiming_normal_(layers[-1].weight, mode='fan_in', nonlinearity='relu')
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(hidden_dim))
            if dropout_rate > 0 and i < n_hidden - 1:
                # No dropout before the last hidden layer, for more stability and precision.
                layers.append(nn.Dropout(dropout_rate))
            in_features = hidden_dim
        layers.append(nn.Linear(in_features, out_features, bias=bias))
        self.layers = nn.ModuleList(layers)
        # print("Gating network layers:", self.layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        orig_shape = x.shape
        # Flatten all leading dimensions so BatchNorm1d sees a 2-D batch.
        x = x.reshape((-1, x.shape[-1]))
        for layer in self.layers:
            x = layer(x)
        # Temperature-scaled softmax over the objectives dimension.
        x = F.softmax(x / self.temperature, dim=1)
        x = x.reshape(list(orig_shape[:-1]) + [x.shape[-1]])
        return x * self.logit_scale


# Gemma2 token IDs of "\nmodel\n"; used to locate the end of the prompt for the gating network.
token_pattern = [107, 108, 106, 2516, 108]


def find_token_for_gating(lst):
    """Return the start index of the last occurrence of `token_pattern` in a list of token ids."""
    token_pattern_len = len(token_pattern)
    search_end = len(lst)
    for j in range(search_end - token_pattern_len, -1, -1):
        if lst[j:j + token_pattern_len] == token_pattern:
            return j
    raise ValueError("Token pattern not found in the list.")
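
# --- Illustrative sketch (not part of the original model code) ---------------
# The helper below shows how GatingNetwork and find_token_for_gating behave on
# dummy inputs; the dimensions and token ids are arbitrary and chosen only for
# demonstration.
def _demo_gating_and_token_search():
    gating = GatingNetwork(in_features=16, out_features=5, hidden_dim=32, n_hidden=2)
    gating.eval()  # use BatchNorm running stats / disable dropout for the demo
    dummy_prompt_embedding = torch.randn(4, 16)
    # Mixture weights over the 5 objectives, shape [4, 5]; each row sums to
    # logit_scale (1.0 at initialization) because of the softmax.
    weights = gating(dummy_prompt_embedding)

    # find_token_for_gating returns the start index of the last occurrence of
    # token_pattern in a list of token ids, here 3.
    dummy_ids = [2, 1596, 4] + token_pattern + [9413, 651]
    position = find_token_for_gating(dummy_ids)
    return weights, position
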
@dataclass
class CustomOutput(ModelOutput):
    """
    Output type of `Gemma2ForQuantileSequenceClassification`.

    Args:
        reward_quantiles (`torch.FloatTensor` of shape `(batch_size, num_quantiles)`):
            Gating-weighted quantile estimates of the reward, averaged over objectives.
        rewards (`torch.FloatTensor` of shape `(batch_size, num_objectives)`):
            Expected reward per objective (mean over the quantile estimates).
        gating_output (`torch.FloatTensor` of shape `(batch_size, num_objectives)`):
            The output of the gating network, i.e. the mixture weights over objectives.
        score (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            The final reward score.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Same as `score`.
    """

    reward_quantiles: torch.FloatTensor = None
    rewards: torch.FloatTensor = None
    gating_output: Optional[torch.FloatTensor] = None
    score: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


class Gemma2ForQuantileSequenceClassification(Gemma2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma2Model(config)
        # self.model = Gemma2Model(config).to(torch.bfloat16)

        config_dict = config.to_dict()
        self.num_objectives = config_dict.get("num_objectives", 5)
        self.num_quantiles = config_dict.get("num_quantiles", 19)
        # Interior quantile levels, e.g. 0.05, 0.10, ..., 0.95 for num_quantiles=19.
        self.quantiles = torch.linspace(0., 1., self.num_quantiles + 2)[1:-1]
        # Regression head predicting num_quantiles quantile estimates per objective.
        self.regression_layer = nn.Linear(
            config.hidden_size, self.num_quantiles * self.num_objectives, bias=False
        )
        # Initialize weights and apply final processing.
        self.post_init()
        # Gating network producing per-objective mixture weights from the prompt embedding.
        self.gating = GatingNetwork(
            config.hidden_size,
            self.num_objectives,
            temperature=config_dict.get("gating_temperature", 1),
            hidden_dim=config_dict.get("gating_hidden_dim", 1024),
            n_hidden=config_dict.get("gating_n_hidden", 3),
        )

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CustomOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Accepted for API compatibility but not used: this head only produces
            reward scores and computes no loss.
""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids.shape[0] == 1 and len(input_ids.shape) == 2 and input_ids[0,0] == input_ids[0,1] == 2: input_ids = input_ids[:, 1:] if attention_mask is not None: attention_mask = attention_mask[:, 1:] transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(hidden_states.device) else: sequence_lengths = -1 dummy_iterator = torch.arange(batch_size, device=hidden_states.device) last_hidden_states = hidden_states[dummy_iterator, sequence_lengths] assert last_hidden_states.shape == (batch_size, self.config.hidden_size) rewards = self.regression_layer(last_hidden_states) rewards = rewards.reshape(-1, self.num_objectives, self.num_quantiles) gating_token_positions = [find_token_for_gating(ids.tolist()) for ids in input_ids] prompt_embedding = hidden_states[dummy_iterator, gating_token_positions, :] gating_output = self.gating(prompt_embedding) # [B, num_objectives, num_quantiles, ] reward_quantiles = torch.mean( rewards * gating_output.unsqueeze(-1).repeat(1, 1, self.num_quantiles), dim=1) rewards_expectation = rewards.mean(dim=2) score = torch.sum(rewards_expectation.float() * gating_output.float(), dim=-1, keepdim=True) return CustomOutput( reward_quantiles=reward_quantiles, rewards=rewards_expectation, gating_output=gating_output, score=score, logits=score, )