TheDemond's picture
Upload 9 files
c412427 verified
import json
class ModelArgs:
"""
A class to parse and store model configuration from a JSON file.
"""
def __init__(self, model_type:str, config_path:str):
"""
Initialize ModelArgs with configuration from a JSON file.
Args:
config_path (str): Path to the JSON configuration file.
Raises:
AssertionError: If the JSON content is invalid or has missing keys.
"""
# Load JSON file
with open(config_path, 'r') as file:
config = json.load(file)
# Validate and assign attributes
self.model_type = model_type.lower()
assert self.model_type in ['s2s', 's2sattention', 'transformer'], \
"Supported model_type values are ['s2s', 's2sAttention', 'transformer']."
self.dim_embed = config.get("dim_embed")
assert isinstance(self.dim_embed, int), "dim_embed must be an integer."
self.dim_model = config.get("dim_model")
assert isinstance(self.dim_model, int), "dim_model must be an integer."
self.dim_feedforward = config.get("dim_feedforward")
assert isinstance(self.dim_feedforward, int), "dim_feedforward must be an integer."
self.num_layers = config.get("num_layers")
assert isinstance(self.num_layers, int), "num_layers must be an integer."
self.dropout = config.get("dropout")
assert isinstance(self.dropout, float), "dropout must be a float."
self.maxlen = config.get("maxlen")
assert isinstance(self.maxlen, int), "maxlen must be an integer."
self.flash_attention = config.get("flash_attention")
assert isinstance(self.flash_attention, bool), "flash_attention must be a boolean."
def __repr__(self):
return (f"ModelArgs(\n" +
f"model_type={self.model_type},\n" +
f"dim_embed={self.dim_embed},\n" +
f"dim_model={self.dim_model},\n" +
f"dim_feedforward={self.dim_feedforward},\n" +
f"num_layers={self.num_layers},\n" +
f"dropout={self.dropout},\n" +
f"maxlen={self.maxlen},\n" +
f"flash_attention={self.flash_attention}\n" +
")")