|
import importlib |
|
|
|
|
|
import logging |
|
|
|
# Module-level logger registered under the "pytorch_lightning" name.
logger = logging.getLogger(name="pytorch_lightning")
|
|
|
from pytorch_lightning.utilities.rank_zero import ( |
|
rank_zero_debug, |
|
rank_zero_info, |
|
rank_zero_only, |
|
) |
|
|
|
|
|
def find(cls_string):
    """Import and return the object named by a dotted path.

    Everything before the last dot is imported as a module; the part after
    the last dot is looked up as an attribute on that module.

    Args:
        cls_string: Dotted path such as ``"package.module.Name"``.

    Returns:
        The attribute ``Name`` fetched from the imported module.

    Raises:
        ValueError: If ``cls_string`` has no module part before the last dot.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    # A single rpartition yields both halves without splitting twice.
    module_string, _, cls_name = cls_string.rpartition(".")
    if not module_string:
        # import_module("") would also raise ValueError, but with the opaque
        # message "Empty module name"; name the offending input instead.
        raise ValueError(
            f"Expected a dotted path like 'package.module.Name', got {cls_string!r}"
        )
    module = importlib.import_module(module_string, package=None)
    return getattr(module, cls_name)
|
|
|
|
|
# Short module-level aliases for the helpers imported from
# pytorch_lightning.utilities.rank_zero.
debug, info = rank_zero_debug, rank_zero_info
|
|
|
|
|
@rank_zero_only
def warn(*args, **kwargs):
    """Emit a warning-level record on the module-level ``logger``.

    All positional and keyword arguments are forwarded unchanged to
    ``logger.warning``.

    NOTE(review): the ``rank_zero_only`` decorator presumably limits the call
    to the rank-zero process -- confirm against the Lightning documentation.
    """
    # ``Logger.warn`` is a deprecated alias that was removed in Python 3.13;
    # ``Logger.warning`` is the supported spelling with identical arguments.
    logger.warning(*args, **kwargs)
|
|