File size: 8,002 Bytes
8c31d70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import pkgutil
import sys
from dataclasses import fields as dataclass_fields
from dataclasses import is_dataclass
from typing import Any, Dict, Optional
import attr
import attrs
from hydra import compose, initialize
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf
from .log import log
from .config import Config
from .inference import *
def is_attrs_or_dataclass(obj) -> bool:
    """
    Tell whether ``obj`` is backed by a dataclass or by an attrs class.

    Args:
        obj: The object to inspect.

    Returns:
        bool: True when ``dataclasses.is_dataclass(obj)`` holds, or when
            ``attr.has`` recognises ``type(obj)``; False otherwise.
    """
    # The dataclass test is tried first; attrs is only consulted when it fails.
    if is_dataclass(obj):
        return True
    return attr.has(type(obj))
def get_fields(obj):
    """
    List the field names declared on a dataclass or attrs object.

    Args:
        obj: The object whose fields are wanted. It must satisfy
            ``dataclasses.is_dataclass`` or be an instance of an attrs class.

    Returns:
        list: The field names, in the order the underlying library reports them.

    Raises:
        ValueError: If the object is neither an attrs class nor a dataclass.
    """
    # Pick the field descriptors from whichever library owns the type, then
    # extract the names in one place.
    if is_dataclass(obj):
        field_specs = dataclass_fields(obj)
    elif attr.has(type(obj)):
        field_specs = attr.fields(type(obj))
    else:
        raise ValueError("The object is neither an attrs class nor a dataclass.")
    return [spec.name for spec in field_specs]
def override(config: Config, overrides: Optional[list[str]] = None) -> Config:
    """
    Apply Hydra-style overrides to a ``Config`` instance and rebuild it.

    The config is converted to a dict with ``attrs.asdict``, wrapped in a
    ``DictConfig``, registered in Hydra's ``ConfigStore`` under the name
    ``"config"``, composed together with the overrides, resolved, and finally
    converted back into objects of the same types as the original ``config``.

    :param config: the instance of class `Config` to use as the base
    :param overrides: list of override strings; when non-empty, its first element
        must be the literal ``"--"`` separator, which is dropped before the
        remaining entries are handed to Hydra
    :return: the composed instance of class `Config`
    :raises ValueError: if ``overrides`` is non-empty and does not start with ``"--"``
    :raises AssertionError: if the composed data has keys the reference objects lack
    """
    # Convert the Config object into a DictConfig; the "allow_objects" flag is
    # set on the node.
    base_node = DictConfig(content=attrs.asdict(config), flags={"allow_objects": True})

    # Enforce the "--" separator between the script arguments and the overrides.
    hydra_overrides = overrides
    if hydra_overrides:
        if hydra_overrides[0] != "--":
            raise ValueError('Hydra config overrides must be separated with a "--" token.')
        hydra_overrides = hydra_overrides[1:]

    # Let Hydra compose the stored node with the overrides.
    ConfigStore.instance().store(name="config", node=base_node)
    with initialize(version_base=None):
        composed = compose(config_name="config", overrides=hydra_overrides)
        OmegaConf.resolve(composed)

    def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
        """
        Rebuild a value of the same type as ``ref_instance`` from unstructured data.

        Args:
            ref_instance: Reference value. When it is an attrs/dataclass object, its
                type and fields drive the reconstruction; otherwise it is ignored.
            kwargs: A dict/DictConfig of field values when ``ref_instance`` is
                structured; any other data is passed through untouched.

        Returns:
            Any: ``kwargs`` itself for unstructured references, otherwise a new
                instance of ``type(ref_instance)`` built from the converted fields.

        Raises:
            AssertionError: If ``kwargs`` is not a mapping or contains keys that are
                not fields of ``ref_instance``.
            Exception: Whatever the constructor raises (logged, then re-raised).
        """
        # Leaves (anything that is not attrs/dataclass) are returned as-is.
        if not is_attrs_or_dataclass(ref_instance):
            return kwargs

        ref_fields = set(get_fields(ref_instance))
        assert isinstance(kwargs, (dict, DictConfig)), "kwargs must be a dictionary or a DictConfig"
        keys = set(kwargs.keys())

        # Every provided key must be a known field; missing fields are allowed.
        extra_keys = keys - ref_fields
        assert (
            keys <= ref_fields
        ), f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"

        # Recurse field by field, using the reference attribute as the type template.
        resolved_kwargs: Dict[str, Any] = {
            name: config_from_dict(getattr(ref_instance, name), kwargs[name]) for name in keys
        }

        try:
            return type(ref_instance)(**resolved_kwargs)
        except Exception as e:
            log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
            log.error(e)
            raise e

    return config_from_dict(config, composed)
def get_config_module(config_file: str) -> str:
    """
    Convert the path of a Python config file into a dotted module name.

    Args:
        config_file (str): Path to the config file, e.g. ``"pkg/sub/my_config.py"``.
            Forward slashes are treated as the package separator.

    Returns:
        str: The path with a trailing ``.py`` removed and every ``/`` replaced
            by ``.`` (``"pkg/sub/my_config.py"`` -> ``"pkg.sub.my_config"``).

    Raises:
        AssertionError: If ``config_file`` is not an existing file.
    """
    py_suffix = ".py"
    has_py_suffix = config_file.endswith(py_suffix)
    if not has_py_suffix:
        # Logged only, not raised: the existence check below still decides
        # whether the call fails.
        log.error("Config file cannot be specified as module.")
        log.error("Please provide the path to the Python config file (relative to the Cosmos root).")

    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O`; the exception type stays AssertionError for callers.
    if not os.path.isfile(config_file):
        raise AssertionError(f"Cosmos config file ({config_file}) not found.")

    # Strip only the trailing extension. A blanket str.replace(".py", "") applied
    # after turning "/" into "." would also eat the start of any path component
    # beginning with "py" (e.g. "a/pytools/cfg.py" -> "atools.cfg").
    module_path = config_file[: -len(py_suffix)] if has_py_suffix else config_file
    return module_path.replace("/", ".")
def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None:
    """
    Disabled placeholder for recursively importing every module of a package.

    This function is intentionally a no-op: it returns immediately and imports
    nothing. The signature is kept so existing call sites keep working, e.g.:
    ```python
    import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True, skip_underscore=False)
    ```

    Args:
        package_path (str): The dotted path to a package. Accepted and ignored.
        reload (bool): Accepted and ignored.
        skip_underscore (bool): Accepted and ignored.

    Returns:
        None: Always.
    """
    # The recursive import logic that used to follow an unconditional `return`
    # was unreachable and has been removed; callers see the same no-op as before.
    return None
|