from transformers import PreTrainedModel

from .learner import PrefLearner
from .configuration_pal_b_rm import PAL_B_Config


class PAL_B_RM_opt(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``PrefLearner`` built from a ``PAL_B_Config``."""

    config_class = PAL_B_Config

    def __init__(self, config):
        super().__init__(config)
        # Build the underlying preference learner from the PAL-B configuration.
        self.model = PrefLearner(
            d_hid=config.d_hid,
            d_pref=config.d_pref,
            k=config.k,
            llm_name=config.llm_name,
            pref_learner_type=config.pref_learner_type,
            proj_arch=config.proj_arch,
            initializer_type=config.initializer_type,
            is_expectation_norm_init=config.is_expectation_norm_init,
            sfx_type=config.sfx_type,
            sfx_temperature=config.sfx_temperature,
            is_temperature_learnable=config.is_temperature_learnable,
            is_gumbel_hard=config.is_gumbel_hard,
            uids=config.uids,
        )
        # Initialize the per-user weights when user ids are provided in the config.
        if config.uids is not None:
            self.model.user_learner.init_weight(config.uids)

    def forward(self, x):
        # Delegate to the wrapped preference learner and return its scores as logits.
        logits = self.model(x)
        return {'logits': logits}
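
# Usage sketch (hedged): `from_pretrained` is the standard PreTrainedModel loader
# inherited by this class; the checkpoint path below is a hypothetical placeholder,
# and the expected format of the input batch `x` is defined by PrefLearner, not here.
#
#   model = PAL_B_RM_opt.from_pretrained("path/to/pal_b_rm_checkpoint")
#   outputs = model(x)  # -> {'logits': ...}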