#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@Desc: Implementation of PAL-B (personalized reward model).
'''
import logging
from collections import defaultdict
from typing import Literal, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel

from .connector import Connector
from .custom_sfx import CustomSoftMax
from .itemLearner import ItemLearner
from .tensor_initializer import TensorInitializer
from .userLearner import UserLearner

logger = logging.getLogger(__name__)


class BasePrefLearner(nn.Module):
    def __init__(
        self,
        d_hid: int,
        d_pref: int,
        k: int,
        llm_name: str,
        pref_learner_type: Literal["dist", "dist_normalization", "angle", "norm", "dist_logistic", "angle_hinge"],
        proj_arch: str,
        initializer_type: Literal["gaussian"],
        is_expectation_norm_init: bool,  # the tensor initialization parameter
        sfx_type: Literal["gumbel_softmax", "softmax"],
        sfx_temperature: float,
        is_temperature_learnable: bool,
        is_gumbel_hard: Optional[bool] = None,
        is_partition: bool = False,
        seed: int = 42,
        **kwargs,
    ):
        super().__init__()
        self.pref_learner_type = pref_learner_type
        self.is_temperature_learnable = is_temperature_learnable

        # init all necessary modules
        model_config = AutoConfig.from_pretrained(llm_name)
        self.llm = AutoModel.from_pretrained(
            llm_name, from_tf=bool(".ckpt" in llm_name), config=model_config
        )
        self.tensor_initializer = TensorInitializer(
            initializer_type, seed, is_expectation_norm_init=is_expectation_norm_init
        )
        # f: item projector; g_1..g_k: one projector per user prototype (owned by UserLearner below)
        self.projector_f = Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref)
        self.projectors_gk = [
            Connector(cnct_arch=proj_arch, in_dims=d_hid, out_dims=d_pref) for _ in range(k)
        ]
        # learnable logit scale, initialized as in CLIP (log(1/0.07))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.softmax_w = CustomSoftMax(
            sfx_type=sfx_type,
            temperature=sfx_temperature,
            is_temperature_learnable=is_temperature_learnable,
            is_gumbel_hard=is_gumbel_hard,
        )
        self.item_learner = ItemLearner(
            llm=self.llm,
            projector=self.projector_f,
        )
        self.is_partition = is_partition
        self.user_learner = UserLearner(
            k=k,
            llm=self.llm,
            projectors=self.projectors_gk,
            softmax=self.softmax_w,
            is_partition=is_partition,
        )
        logger.critical('🛑 Remember to call update_trainable_params() after the model is initialized.')

    def update_trainable_params(self, fix_modules: Tuple[str, ...] = ()):
        # Collect the parameters to optimize, grouped by module name.
        # Any module named in `fix_modules` is left out (i.e. frozen).
        self.trainable_params = defaultdict(list)
        if "llm" not in fix_modules:
            self.trainable_params["llm"] = self.llm.parameters()
        else:
            self.llm.eval()
        if "itemLearnerProjector" not in fix_modules:
            self.trainable_params["projector_f"].extend(self.item_learner.projector.parameters())
        if "userLearnerProjector" not in fix_modules:
            self.trainable_params["projectors_gk"].extend(list(self.user_learner.projectors.parameters()))
        if "W" not in fix_modules:
            self.trainable_params["W"] = self.user_learner.W.parameters()
        if self.pref_learner_type in ["angle", "dist_logistic"] and "logit_scale" not in fix_modules:
            # single nn.Parameter, not an iterable
            self.trainable_params["logit_scale"] = self.logit_scale
        if self.is_temperature_learnable and "temperature" not in fix_modules:
            self.trainable_params["temperature"] = self.softmax_w.temperature

    def map_to_pref_embedding_space(self, x, rm_cached=None):
        '''Project tokenized prompt/item inputs into the preference embedding space.

        Expected layout of ``x``:
            (
                uid,
                {'input_ids': prompt_input_ids, 'attention_mask': prompt_attention_mask},
                {'input_ids': eval_input_ids,   'attention_mask': eval_attention_mask},
            )
        '''
        uid, prompt, items = x
        if rm_cached is None:
            items_prime = self.item_learner(items)
            prompt_prime = self.user_learner(uid, prompt)
            return items_prime, prompt_prime
        else:
            items_prime, rm_cached = self.item_learner(items, rm_cached)
            prompt_prime, rm_cached = self.user_learner(uid, prompt, rm_cached)
            return items_prime, prompt_prime, rm_cached


class PrefLearner(BasePrefLearner):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.uid = None  # set via specify_user_ids() before calling forward()

    def specify_user_ids(self, uid):
        # personalize the model for a specific user
        self.uid = uid

    def forward(self, x, rm_cached=None):
        assert self.uid is not None, (
            "Please specify the user id first by calling specify_user_ids() "
            "to personalize the reward model"
        )
        prompt, items = x
        if rm_cached is None:
            items_prime, prompt_prime = self.map_to_pref_embedding_space((self.uid, prompt, items))
        else:
            items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space(
                (self.uid, prompt, items), rm_cached
            )

        if self.pref_learner_type == 'angle':
            # NOTICE: here we implement the "last token only" version of PAL-B:
            # the score is the scaled cosine similarity of the last-token embeddings.
            prompt_last_prime = prompt_prime[:, -1, :].unsqueeze(1)
            prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
            items_last_prime = items_prime[:, -1, :].unsqueeze(1)
            items_last_prime = items_last_prime / torch.norm(items_last_prime, dim=-1, keepdim=True)
            # CLIP-style learnable temperature, clamped for numerical stability
            logit_scale = self.logit_scale.exp()
            clamped_logit_scale = torch.clamp(logit_scale, max=100)
            sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale  # (bs, 1)
            if rm_cached is None:
                return sim_score
            else:
                return sim_score, rm_cached
        else:
            raise NotImplementedError(
                f"pref_learner_type={self.pref_learner_type!r} is not implemented in forward()"
            )
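

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). A minimal example of how the pieces above
# are meant to be wired together: instantiate PrefLearner, freeze the backbone,
# build an optimizer from the collected parameter groups, and score one
# (prompt, item) pair for a user. The backbone name ("gpt2"), the hyper-parameter
# values, the proj_arch string, and the uid format are placeholder assumptions,
# not the values used by the actual training scripts. Tokenized inputs follow
# the layout documented in BasePrefLearner.map_to_pref_embedding_space().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = PrefLearner(
        d_hid=768,                   # hidden size of the backbone LLM (assumed)
        d_pref=64,                   # preference-embedding dimension (assumed)
        k=5,                         # number of user prototypes (assumed)
        llm_name="gpt2",             # placeholder HF checkpoint
        pref_learner_type="angle",   # the only branch implemented in forward()
        proj_arch="mlp",             # placeholder Connector architecture string
        initializer_type="gaussian",
        is_expectation_norm_init=True,
        sfx_type="softmax",
        sfx_temperature=1.0,
        is_temperature_learnable=False,
    )

    # Freeze the backbone; collect the remaining parameter groups.
    model.update_trainable_params(fix_modules=("llm",))
    param_groups = []
    for name, params in model.trainable_params.items():
        # Entries are either iterables of Parameters or a single Parameter.
        params = [params] if isinstance(params, nn.Parameter) else list(params)
        param_groups.append({"params": params})
    optimizer = torch.optim.AdamW(param_groups, lr=1e-4)  # training loop would go here

    # Score one (prompt, item) pair for user 0. The uid format must match what
    # UserLearner expects; a LongTensor of user indices is assumed here.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    prompt = dict(tokenizer(["Write a short bedtime story."], return_tensors="pt"))
    item = dict(tokenizer(["Once upon a time ..."], return_tensors="pt"))
    model.specify_user_ids(torch.tensor([0]))
    scores = model((prompt, item))  # shape (bs, 1) for the 'angle' branch
    print(scores)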