File size: 2,347 Bytes
c412427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import json

class ModelArgs:
    """Model hyperparameters loaded and validated from a JSON file.

    Attributes set by ``__init__``:
        model_type (str): Lower-cased model type passed by the caller.
        dim_embed (int): Value of the ``"dim_embed"`` config key.
        dim_model (int): Value of the ``"dim_model"`` config key.
        dim_feedforward (int): Value of the ``"dim_feedforward"`` config key.
        num_layers (int): Value of the ``"num_layers"`` config key.
        dropout (float): Value of the ``"dropout"`` config key.
        maxlen (int): Value of the ``"maxlen"`` config key.
        flash_attention (bool): Value of the ``"flash_attention"`` config key.
    """

    # Accepted values for ``model_type`` after lower-casing.
    _SUPPORTED_MODEL_TYPES = ('s2s', 's2sattention', 'transformer')

    # Config keys whose values must be genuine integers.
    _INT_KEYS = ('dim_embed', 'dim_model', 'dim_feedforward', 'num_layers')

    # Attribute names in the order they are rendered by ``__repr__``.
    _REPR_FIELDS = (
        'model_type',
        'dim_embed',
        'dim_model',
        'dim_feedforward',
        'num_layers',
        'dropout',
        'maxlen',
        'flash_attention',
    )

    def __init__(self, model_type: str, config_path: str):
        """Read ``config_path`` and populate the validated attributes.

        Args:
            model_type (str): One of ``'s2s'``, ``'s2sAttention'`` or
                ``'transformer'``; matched case-insensitively.
            config_path (str): Path to the JSON configuration file.

        Raises:
            AssertionError: If ``model_type`` is unsupported, the JSON top
                level is not an object, or a key is missing or has the
                wrong type.
        """
        # JSON is UTF-8 by specification, so do not depend on the locale.
        with open(config_path, 'r', encoding='utf-8') as file:
            config = json.load(file)

        # Without this check a list/scalar top level would surface as an
        # AttributeError on ``.get`` instead of the documented error type.
        self._require(
            isinstance(config, dict),
            "The JSON configuration must be an object.",
        )

        self.model_type = model_type.lower()
        self._require(
            self.model_type in self._SUPPORTED_MODEL_TYPES,
            "Supported model_type values are ['s2s', 's2sAttention', 'transformer'].",
        )

        for key in self._INT_KEYS:
            setattr(self, key, self._read_int(config, key))

        self.dropout = config.get("dropout")
        self._require(isinstance(self.dropout, float), "dropout must be a float.")

        self.maxlen = self._read_int(config, "maxlen")

        self.flash_attention = config.get("flash_attention")
        self._require(
            isinstance(self.flash_attention, bool),
            "flash_attention must be a boolean.",
        )

    @staticmethod
    def _require(condition, message):
        """Raise ``AssertionError(message)`` when ``condition`` is false.

        Raised explicitly instead of using an ``assert`` statement so the
        validation still runs when Python is started with ``-O``.
        """
        if not condition:
            raise AssertionError(message)

    @classmethod
    def _read_int(cls, config, key):
        """Return ``config[key]`` after checking it is an int and not a bool.

        A missing key yields ``None`` from ``dict.get`` and fails the check.
        """
        value = config.get(key)
        # bool is a subclass of int, so JSON true/false must be excluded
        # explicitly or it would be accepted as an integer.
        cls._require(
            isinstance(value, int) and not isinstance(value, bool),
            f"{key} must be an integer.",
        )
        return value

    def __repr__(self):
        """Return a multi-line ``ModelArgs(...)`` listing, one field per line."""
        body = ",\n".join(
            f"{name}={getattr(self, name)}" for name in self._REPR_FIELDS
        )
        return f"ModelArgs(\n{body}\n)"