File size: 543 Bytes
5eb7d19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
from transformers import LlamaConfig
class LlamaWithFeatsEncoderConfig(LlamaConfig):
    """``LlamaConfig`` extended with one extra setting, ``feats_hidden_size``.

    Every standard ``LlamaConfig`` option is accepted through ``**kwargs`` and
    forwarded to the parent constructor unchanged.
    """

    # Registered identifier for this config class; also written into the
    # serialized dict by ``to_dict``.
    model_type = "llama_with_feats_encoder"

    def __init__(self, feats_hidden_size=8, **kwargs):
        """Build the configuration.

        Args:
            feats_hidden_size: Stored as ``self.feats_hidden_size`` (default 8).
                Presumably the hidden width of the features encoder —
                confirm against the model code that consumes this config.
            **kwargs: Passed verbatim to ``LlamaConfig.__init__``.
        """
        super().__init__(**kwargs)
        # Assigned after the parent constructor runs, so the parent never
        # sees this attribute during its own initialisation.
        self.feats_hidden_size = feats_hidden_size

    def to_dict(self):
        """Return a plain-dict snapshot of this configuration.

        Starts from the parent's serialization. It then (re)writes the
        ``model_type`` and ``feats_hidden_size`` entries from this instance,
        so both are always present in the result.
        """
        serialized = super().to_dict()
        serialized.update(
            model_type=self.model_type,
            feats_hidden_size=self.feats_hidden_size,
        )
        return serialized
|