File size: 686 Bytes
b69413a
2ceedc4
b69413a
 
 
 
 
 
 
 
 
2ceedc4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from transformers.models.gpt_neox import GPTNeoXPreTrainedModel, GPTNeoXModel
from transformers import PreTrainedTokenizerBase
from .modeling_measurement_pred import MeasurementPredictorMixin
from .configuration_gpt_neox_measurement_pred import GPTNeoXMeasurementPredictorConfig

class GPTNeoXMeasurementPredictor(GPTNeoXPreTrainedModel, MeasurementPredictorMixin):
    """Measurement predictor that uses a GPT-NeoX transformer as its backbone.

    Combines the ``GPTNeoXPreTrainedModel`` base class with
    ``MeasurementPredictorMixin`` and stores the backbone in ``self.gpt_neox``.
    Configuration is described by ``GPTNeoXMeasurementPredictorConfig``.
    """

    config_class = GPTNeoXMeasurementPredictorConfig

    def __init__(self, config: GPTNeoXMeasurementPredictorConfig) -> None:
        """Construct the GPT-NeoX backbone from ``config`` and finish setup.

        Args:
            config: Model configuration. It is forwarded to the base-class
                constructor and used to build the ``GPTNeoXModel`` backbone.
        """
        super().__init__(config)
        # NOTE(review): the attribute name presumably mirrors
        # GPTNeoXPreTrainedModel.base_model_prefix ("gpt_neox") so that
        # pretrained checkpoint weights map onto this submodule -- verify.
        backbone = GPTNeoXModel(config)
        self.gpt_neox = backbone
        # transformers' post-construction hook; called last so it sees every
        # submodule registered above.
        self.post_init()

    def set_pad_token(self, tokenizer: PreTrainedTokenizerBase) -> None:
        """Register ``"[PAD]"`` as the padding token of ``tokenizer``.

        ``add_special_tokens`` is called on the tokenizer, so ``"[PAD]"`` may be
        added to its vocabulary if it is not already present. This method
        itself does not touch the model's embedding matrix or config.
        NOTE(review): confirm the caller resizes embeddings / sets
        ``pad_token_id`` if that is required.

        Args:
            tokenizer: Tokenizer to modify in place.
        """
        pad_spec = {"pad_token": "[PAD]"}
        tokenizer.add_special_tokens(pad_spec)