# pal-b-large-opt-350m / modeling_pal_b_rm.py
# Uploaded by daiweichen ("Upload PAL_B_RM_opt", commit e424d6b, verified)
from transformers import PreTrainedModel
from .learner import PrefLearner
from .configuration_pal_b_rm import PAL_B_Config
class PAL_B_RM_opt(PreTrainedModel):
    """HF-compatible reward model that wraps a ``PrefLearner``.

    All hyperparameters are read off the ``PAL_B_Config`` instance and
    forwarded verbatim to the underlying preference learner.
    """

    config_class = PAL_B_Config

    # Config attributes forwarded one-to-one as PrefLearner kwargs.
    _LEARNER_FIELDS = (
        "d_hid",
        "d_pref",
        "k",
        "llm_name",
        "pref_learner_type",
        "proj_arch",
        "initializer_type",
        "is_expectation_norm_init",
        "sfx_type",
        "sfx_temperature",
        "is_temperature_learnable",
        "is_gumbel_hard",
        "uids",
    )

    def __init__(self, config):
        super().__init__(config)
        learner_kwargs = {name: getattr(config, name) for name in self._LEARNER_FIELDS}
        self.model = PrefLearner(**learner_kwargs)
        # When user ids are provided up front, seed the per-user weights now.
        if config.uids is not None:
            self.model.user_learner.init_weight(config.uids)

    def forward(self, x):
        """Run the wrapped learner and return its output under the ``logits`` key."""
        return {"logits": self.model(x)}