File size: 8,002 Bytes
8c31d70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
import pkgutil
import sys
from dataclasses import fields as dataclass_fields
from dataclasses import is_dataclass
from typing import Any, Dict, Optional

import attr
import attrs
from hydra import compose, initialize
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf

from .log import log
from .config import Config
from .inference import *


def is_attrs_or_dataclass(obj) -> bool:
    """
    Report whether ``obj`` is backed by a dataclass or by an attrs class.

    Args:
        obj: The object to inspect.

    Returns:
        bool: True if ``dataclasses.is_dataclass(obj)`` holds, or if ``type(obj)`` is an
        attrs class; False otherwise.

    Note:
        ``dataclasses.is_dataclass`` is also True for a dataclass *class* (not only an
        instance), whereas an attrs *class* passed here yields False because its type
        is ``type``.
    """
    if is_dataclass(obj):
        return True
    # attrs metadata lives on the class, so inspect the object's type.
    return attr.has(type(obj))


def get_fields(obj):
    """
    List the field names declared by a dataclass or an attrs class.

    Args:
        obj: An instance of a dataclass or of an attrs class (a dataclass class is
            accepted too, since ``dataclasses.fields`` handles both).

    Returns:
        list: Field names, in declaration order.

    Raises:
        ValueError: If ``obj`` is backed by neither a dataclass nor an attrs class.
    """
    if is_dataclass(obj):
        declared = dataclass_fields(obj)
    elif attr.has(type(obj)):
        # attrs exposes its field metadata on the class, not on the instance.
        declared = attr.fields(type(obj))
    else:
        raise ValueError("The object is neither an attrs class nor a dataclass.")
    return [spec.name for spec in declared]


def override(config: Config, overrides: Optional[list[str]] = None) -> Config:
    """
    Apply Hydra-style overrides to a config object and rebuild it with its original types.

    The config is converted to a dict (``attrs.asdict``), wrapped in an OmegaConf
    ``DictConfig``, registered in Hydra's ``ConfigStore`` under the name ``"config"``,
    composed with ``overrides``, resolved, and finally converted back into instances of
    the same attrs/dataclass types as the input ``config``.

    :param config: the instance of class `Config` (usually from `make_config`)
    :param overrides: list of overrides for config. When non-empty, the first element must
        be the ``"--"`` separator token; it is stripped before the rest is passed to Hydra.
    :return: the composed instance of class `Config`
    :raises ValueError: if ``overrides`` is non-empty and does not start with ``"--"``.
    :raises AssertionError: if a composed sub-config is not a mapping, or contains a key
        that is not a field of the corresponding attrs/dataclass type.
    """
    # Convert Config object to a DictConfig object. The "allow_objects" flag lets
    # non-primitive Python values held in the config pass through OmegaConf.
    config_dict = attrs.asdict(config)
    config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
    # Enforce "--" separator between the script arguments and overriding configs.
    if overrides:
        if overrides[0] != "--":
            raise ValueError('Hydra config overrides must be separated with a "--" token.')
        overrides = overrides[1:]
    # Use Hydra to handle overrides
    cs = ConfigStore.instance()
    cs.store(name="config", node=config_omegaconf)
    with initialize(version_base=None):
        config_omegaconf = compose(config_name="config", overrides=overrides)
        OmegaConf.resolve(config_omegaconf)

    def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
        """
        Rebuild a value shaped like ``ref_instance`` from composed (dict-like) data.

        Args:
            ref_instance: The reference value; its type and field list drive reconstruction.
            kwargs: The composed data for this value: a mapping when ``ref_instance`` is an
                attrs/dataclass object, otherwise the value to return unchanged.

        Returns:
            Any: A new instance of ``type(ref_instance)`` built recursively from ``kwargs``
            when ``ref_instance`` is an attrs/dataclass object; otherwise ``kwargs`` itself.

        Raises:
            AssertionError: If ``kwargs`` is not a mapping or has keys that are not fields.
            Exception: Whatever the type's constructor raises (logged, then re-raised).
        """
        if not is_attrs_or_dataclass(ref_instance):
            # Leaf / unstructured value: take the composed value as-is.
            return kwargs

        # Explicit raises (not ``assert``) so validation survives ``python -O``;
        # AssertionError is kept as the exception type for backward compatibility.
        if not isinstance(kwargs, (dict, DictConfig)):
            raise AssertionError("kwargs must be a dictionary or a DictConfig")

        ref_fields = set(get_fields(ref_instance))
        keys = set(kwargs.keys())

        # Every composed key must be a declared field. Fields missing from kwargs are
        # simply not passed to the constructor.
        extra_keys = keys - ref_fields
        if extra_keys:
            raise AssertionError(
                f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"
            )

        # Recurse field by field, using the current attribute value as the reference.
        resolved_kwargs: Dict[str, Any] = {
            key: config_from_dict(getattr(ref_instance, key), kwargs[key]) for key in keys
        }
        try:
            return type(ref_instance)(**resolved_kwargs)
        except Exception as e:
            log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
            log.error(e)
            raise

    config = config_from_dict(config, config_omegaconf)

    return config


def get_config_module(config_file: str) -> str:
    """
    Convert a path to a Python config file into a dotted, importable module name.

    Example: ``"cosmos1/models/config.py"`` -> ``"cosmos1.models.config"``.

    Args:
        config_file (str): Path to the ``.py`` config file, relative to the directory the
            resulting module will be imported from.

    Returns:
        str: The dotted module path (trailing ``.py`` removed, separators turned into dots).

    Raises:
        AssertionError: If ``config_file`` is not an existing file.
    """
    if not config_file.endswith(".py"):
        # Only logged here; a dotted module name normally fails the existence check below.
        log.error("Config file cannot be specified as module.")
        log.error("Please provide the path to the Python config file (relative to the Cosmos root).")
    assert os.path.isfile(config_file), f"Cosmos config file ({config_file}) not found."
    # Strip only the trailing ".py". str.replace would also mangle any ".py" inside a
    # directory name, e.g. "my.python_pkg". normpath drops a leading "./" and duplicate
    # separators, which would otherwise yield empty module-name components.
    module_path = os.path.normpath(config_file.removesuffix(".py"))
    # Convert to importable module format.
    config_module = module_path.replace(os.sep, ".").replace("/", ".")
    return config_module


def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None:
    """
    Recursively import (or reload) every module under a package -- currently DISABLED.

    NOTE: the first statement of the body is ``return``, so calling this function has no
    effect and always returns ``None``. The traversal below it is unreachable; it is kept
    for reference only.

    When enabled, the traversal would import ``package_path``, walk each directory in its
    ``__path__`` with ``pkgutil.iter_modules``, import (or, with ``reload=True``, reload
    already-loaded) each module, and descend into sub-packages.

    Example usage:
    ```python
    import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True, skip_underscore=False)
    ```

    Args:
        package_path (str): Dotted path of the package to walk.
        reload (bool): Reload modules that are already present in ``sys.modules``.
        skip_underscore (bool): Skip modules whose name starts with an underscore.
    """
    return  # we do not use this function
    verb = "Reloading" if reload else "Importing"
    log.debug(f"{verb} all modules from package {package_path}")
    root_package = importlib.import_module(package_path)

    def _walk(search_dir: str, dotted_prefix: str) -> None:
        # Visit the modules directly inside search_dir, recursing into sub-packages.
        for _finder, name, is_pkg in pkgutil.iter_modules([search_dir]):
            if skip_underscore and name.startswith("_"):
                log.debug(f"Skipping module {name} as it starts with an underscore")
                continue

            qualified = f"{dotted_prefix}.{name}"
            log.debug(f"{verb} module {qualified}")

            already_loaded = qualified in sys.modules
            if reload and already_loaded:
                importlib.reload(sys.modules[qualified])
            else:
                importlib.import_module(qualified)

            if is_pkg:
                _walk(os.path.join(search_dir, name), qualified)

    for search_dir in root_package.__path__:
        _walk(search_dir, package_path)