File size: 543 Bytes
5eb7d19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import LlamaConfig


class LlamaWithFeatsEncoderConfig(LlamaConfig):
    """Llama configuration extended with a ``feats_hidden_size`` setting.

    All standard ``LlamaConfig`` keyword arguments are forwarded unchanged to
    the parent constructor. The only addition is ``feats_hidden_size``, which
    is stored on the instance and always emitted by :meth:`to_dict`.
    """

    # Identifier written into the serialized dict by ``to_dict``.
    # NOTE(review): presumably also the key used to register this config with
    # the transformers Auto* factories — confirm against the registration site.
    model_type = "llama_with_feats_encoder"

    def __init__(self, feats_hidden_size=8, **kwargs):
        """Build the config.

        Args:
            feats_hidden_size: Extra size setting stored on the config
                (default 8). Presumably the hidden width of a features
                encoder, going by the name — confirm against the model code.
            **kwargs: Forwarded verbatim to ``LlamaConfig.__init__``.
        """
        # Let the parent populate every Llama field first, then attach the
        # extra attribute so it is set after the base initialization.
        super().__init__(**kwargs)
        self.feats_hidden_size = feats_hidden_size

    def to_dict(self):
        """Return this configuration as a plain Python dictionary.

        Starts from the parent's serialization and then (re)writes
        ``model_type`` and ``feats_hidden_size`` from this instance, so both
        keys are guaranteed to be present in the result.
        """
        serialized = super().to_dict()
        # Explicitly pin the subclass-specific keys. The parent's output may
        # already contain them; overwriting keeps them authoritative.
        serialized.update(
            model_type=self.model_type,
            feats_hidden_size=self.feats_hidden_size,
        )
        return serialized