from transformers import PreTrainedModel

from .learner import PrefLearner
from .configuration_pal_b_rm import PAL_B_Config


class PAL_B_RM_opt(PreTrainedModel):
    """Hugging Face PreTrainedModel wrapper around the PrefLearner preference reward model."""

    config_class = PAL_B_Config

    def __init__(self, config):
        super().__init__(config)
        # Build the underlying preference learner from the config fields.
        self.model = PrefLearner(
            d_hid=config.d_hid,
            d_pref=config.d_pref,
            k=config.k,
            llm_name=config.llm_name,
            pref_learner_type=config.pref_learner_type,
            proj_arch=config.proj_arch,
            initializer_type=config.initializer_type,
            is_expectation_norm_init=config.is_expectation_norm_init,
            sfx_type=config.sfx_type,
            sfx_temperature=config.sfx_temperature,
            is_temperature_learnable=config.is_temperature_learnable,
            is_gumbel_hard=config.is_gumbel_hard,
            uids=config.uids,
        )
        # If user ids are provided, initialize the per-user weights.
        if config.uids is not None:
            self.model.user_learner.init_weight(config.uids)

    def forward(self, x):
        logits = self.model(x)
        return {'logits': logits}
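

# --- Usage sketch (not part of the original file) ---
# Because PAL_B_RM_opt subclasses transformers.PreTrainedModel, a saved
# checkpoint can be loaded with the standard from_pretrained API. The
# checkpoint path and the input batch `x` below are placeholders, and the
# exact batch format is whatever PrefLearner.forward expects.
#
#   model = PAL_B_RM_opt.from_pretrained("path/to/pal_b_rm_checkpoint")
#   outputs = model(x)            # forward returns a dict
#   logits = outputs["logits"]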