|
from typing import OrderedDict, Optional |
|
from PIL import Image |
|
|
|
from toolkit.config_modules import LoggingConfig |
|
|
|
|
|
|
|
class EmptyLogger:
    """No-op logger used when no logging backend is enabled.

    Every method mirrors the signature of its ``WandbLogger`` override and
    discards whatever it is given, so callers can invoke the logger
    unconditionally regardless of which backend was selected.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Accept and ignore any constructor arguments."""

    def start(self):
        """Do nothing; there is no backend session to open."""

    def log(self, *args, **kwargs):
        """Discard the supplied metrics."""

    def commit(self, step: Optional[int] = None):
        """Ignore ``step``; nothing is buffered, so there is nothing to flush."""

    def log_image(self, *args, **kwargs):
        """Discard the supplied image and its metadata."""

    def finish(self):
        """Do nothing; there is no backend session to close."""
|
|
|
|
|
|
|
class WandbLogger(EmptyLogger):
    """Logger backed by Weights & Biases (``wandb``).

    ``wandb`` is imported lazily in :meth:`start`, so constructing this object
    does not require the package to be installed. :meth:`start` must be called
    before :meth:`log`, :meth:`commit` or :meth:`log_image`.
    """

    def __init__(self, project: str, run_name: str | None, config: OrderedDict) -> None:
        """Store run settings; no wandb call is made until :meth:`start`.

        Args:
            project: project name forwarded to ``wandb.init(project=...)``.
            run_name: run name forwarded to ``wandb.init(name=...)``; ``None``
                is passed through unchanged.
            config: configuration forwarded to ``wandb.init(config=...)``.
        """
        super().__init__()
        self.project = project
        self.run_name = run_name
        self.config = config
        # Populated by start(); None means no run has been started yet.
        self.run = None
        self._log = None
        self._image = None

    def _require_started(self) -> None:
        """Raise a clear error when a logging method is used before start()."""
        if getattr(self, "_log", None) is None:
            raise RuntimeError("WandbLogger.start() must be called before logging")

    def start(self):
        """Import wandb and open a run with the stored project/name/config.

        Raises:
            ImportError: if ``wandb`` is not installed (original error chained).
        """
        try:
            import wandb
        except ImportError as err:
            raise ImportError(
                "Failed to import wandb. Please install wandb by running `pip install wandb`"
            ) from err

        self.run = wandb.init(project=self.project, name=self.run_name, config=self.config)
        # Keep references so later calls do not need the module in scope.
        self._log = wandb.log
        self._image = wandb.Image

    def log(self, *args, **kwargs):
        """Forward to ``wandb.log`` with ``commit=False``.

        The data is not committed here; :meth:`commit` later issues the
        ``commit=True`` call. Passing ``commit`` in ``kwargs`` is an error
        (duplicate keyword).

        Raises:
            RuntimeError: if :meth:`start` has not been called.
        """
        self._require_started()
        self._log(*args, **kwargs, commit=False)

    def commit(self, step: Optional[int] = None):
        """Commit pending data via ``wandb.log({}, step=step, commit=True)``.

        Raises:
            RuntimeError: if :meth:`start` has not been called.
        """
        self._require_started()
        self._log({}, step=step, commit=True)

    def log_image(
        self,
        image: Image.Image,
        id,  # shadows the builtin, but kept: callers may pass it by keyword
        caption: str | None = None,
        *args,
        **kwargs,
    ):
        """Wrap ``image`` in ``wandb.Image`` and log it under ``sample_{id}``.

        Extra positional/keyword arguments are forwarded to ``wandb.Image``
        after ``image``. Logged with ``commit=False`` like :meth:`log`.

        Raises:
            RuntimeError: if :meth:`start` has not been called.
        """
        self._require_started()
        wandb_image = self._image(image, *args, caption=caption, **kwargs)
        self._log({f"sample_{id}": wandb_image}, commit=False)

    def finish(self):
        """Finish the active wandb run; a no-op if :meth:`start` never ran."""
        run = getattr(self, "run", None)
        if run is None:
            # Nothing to close, e.g. cleanup after start() itself failed.
            return
        run.finish()
|
|
|
|
|
def create_logger(logging_config: LoggingConfig, all_config: OrderedDict):
    """Return the logger selected by ``logging_config``.

    Args:
        logging_config: read for ``use_wandb``, ``project_name`` and ``run_name``.
        all_config: configuration handed to ``WandbLogger`` as its run config.

    Returns:
        A ``WandbLogger`` when ``logging_config.use_wandb`` is truthy,
        otherwise a no-op ``EmptyLogger``.
    """
    if not logging_config.use_wandb:
        return EmptyLogger()

    return WandbLogger(
        project=logging_config.project_name,
        run_name=logging_config.run_name,
        config=all_config,
    )
|
|