|
import json
|
|
|
|
class ModelArgs:
|
|
"""
|
|
A class to parse and store model configuration from a JSON file.
|
|
"""
|
|
def __init__(self, model_type:str, config_path:str):
|
|
"""
|
|
Initialize ModelArgs with configuration from a JSON file.
|
|
|
|
Args:
|
|
config_path (str): Path to the JSON configuration file.
|
|
|
|
Raises:
|
|
AssertionError: If the JSON content is invalid or has missing keys.
|
|
"""
|
|
|
|
with open(config_path, 'r') as file:
|
|
config = json.load(file)
|
|
|
|
|
|
self.model_type = model_type.lower()
|
|
assert self.model_type in ['s2s', 's2sattention', 'transformer'], \
|
|
"Supported model_type values are ['s2s', 's2sAttention', 'transformer']."
|
|
|
|
self.dim_embed = config.get("dim_embed")
|
|
assert isinstance(self.dim_embed, int), "dim_embed must be an integer."
|
|
|
|
self.dim_model = config.get("dim_model")
|
|
assert isinstance(self.dim_model, int), "dim_model must be an integer."
|
|
|
|
self.dim_feedforward = config.get("dim_feedforward")
|
|
assert isinstance(self.dim_feedforward, int), "dim_feedforward must be an integer."
|
|
|
|
self.num_layers = config.get("num_layers")
|
|
assert isinstance(self.num_layers, int), "num_layers must be an integer."
|
|
|
|
self.dropout = config.get("dropout")
|
|
assert isinstance(self.dropout, float), "dropout must be a float."
|
|
|
|
self.maxlen = config.get("maxlen")
|
|
assert isinstance(self.maxlen, int), "maxlen must be an integer."
|
|
|
|
self.flash_attention = config.get("flash_attention")
|
|
assert isinstance(self.flash_attention, bool), "flash_attention must be a boolean."
|
|
|
|
def __repr__(self):
|
|
return (f"ModelArgs(\n" +
|
|
f"model_type={self.model_type},\n" +
|
|
f"dim_embed={self.dim_embed},\n" +
|
|
f"dim_model={self.dim_model},\n" +
|
|
f"dim_feedforward={self.dim_feedforward},\n" +
|
|
f"num_layers={self.num_layers},\n" +
|
|
f"dropout={self.dropout},\n" +
|
|
f"maxlen={self.maxlen},\n" +
|
|
f"flash_attention={self.flash_attention}\n" +
|
|
")") |