from transformers import PreTrainedModel

from .learner import PrefLearner
from .configuration_pal_b_rm import PAL_B_Config


class PAL_B_RM_opt(PreTrainedModel):
    """`PreTrainedModel` wrapper around a `PrefLearner` configured by `PAL_B_Config`.

    The wrapped learner is stored on ``self.model`` and is constructed purely
    from attributes of the supplied config object.
    """

    config_class = PAL_B_Config

    def __init__(self, config):
        """Build the underlying `PrefLearner` from ``config``.

        Args:
            config: Configuration object (expected to be a `PAL_B_Config`)
                exposing every attribute named in ``learner_fields`` below.
        """
        super().__init__(config)

        # Each config attribute is forwarded to PrefLearner as a keyword
        # argument of the same name, in this order.
        learner_fields = (
            "d_hid",
            "d_pref",
            "k",
            "llm_name",
            "pref_learner_type",
            "proj_arch",
            "initializer_type",
            "is_expectation_norm_init",
            "sfx_type",
            "sfx_temperature",
            "is_temperature_learnable",
            "is_gumbel_hard",
            "uids",
        )
        learner_kwargs = {field: getattr(config, field) for field in learner_fields}
        self.model = PrefLearner(**learner_kwargs)

        # When user ids are provided, let the user learner initialise its
        # weights from them; otherwise skip this step entirely.
        uids = config.uids
        if uids is not None:
            self.model.user_learner.init_weight(uids)

    def forward(self, x):
        """Run ``x`` through the wrapped learner.

        Args:
            x: Input passed unchanged to ``self.model``.
                # NOTE(review): expected structure of ``x`` is defined by
                # PrefLearner and is not visible here — confirm with callers.

        Returns:
            dict: A single-entry mapping ``{'logits': <learner output>}``.
        """
        return {'logits': self.model(x)}