Spaces:
Runtime error
Runtime error
""" | |
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
""" | |
import os | |
class TrainPlatform:
    """Do-nothing base class for training-metric reporters.

    Every method accepts its arguments and returns ``None`` without any
    side effect; subclasses override the hooks they actually support.
    """

    def __init__(self, save_dir):
        """Accept ``save_dir`` and ignore it."""

    def report_scalar(self, name, value, iteration, group_name=None):
        """Hook for reporting one scalar ``value`` at ``iteration``; no-op here."""

    def report_args(self, args, name):
        """Hook for recording run arguments under ``name``; no-op here."""

    def close(self):
        """Hook for releasing reporter resources; no-op here."""
class ClearmlPlatform(TrainPlatform):
    """Reporter that forwards scalars and arguments to a ClearML task."""

    def __init__(self, save_dir):
        """Create a ClearML task named after the last component of ``save_dir``.

        The task is created in the ``'motion_diffusion'`` project; the parent
        directory of ``save_dir`` is passed as the task's ``output_uri``.
        """
        # Imported lazily so that clearml is only required when this
        # platform is actually instantiated.
        from clearml import Task

        path, name = os.path.split(save_dir)
        if not name:
            # ``save_dir`` ended with a path separator (e.g. "save/run/"),
            # which leaves the tail empty; split the head once more so the
            # task still gets the run directory's name.
            path, name = os.path.split(path)

        self.task = Task.init(project_name='motion_diffusion',
                              task_name=name,
                              output_uri=path)
        self.logger = self.task.get_logger()

    def report_scalar(self, name, value, iteration, group_name=None):
        """Report ``value`` for series ``name`` at ``iteration``.

        ``group_name`` is forwarded as the plot title. It defaults to
        ``None`` so the signature matches ``TrainPlatform.report_scalar``.
        NOTE(review): a ``None`` title is passed through unchanged; confirm
        how ClearML treats it before relying on the default.
        """
        self.logger.report_scalar(title=group_name, series=name,
                                  iteration=iteration, value=value)

    def report_args(self, args, name):
        """Attach ``args`` to the task under ``name`` via ``Task.connect``."""
        self.task.connect(args, name=name)

    def close(self):
        """Close the underlying ClearML task."""
        self.task.close()
class TensorboardPlatform(TrainPlatform):
    """Reporter that writes scalars through a TensorBoard ``SummaryWriter``."""

    def __init__(self, save_dir):
        """Open a ``SummaryWriter`` whose ``log_dir`` is ``save_dir``."""
        # Imported lazily so that the tensorboard dependency is only needed
        # when this platform is actually instantiated.
        from torch.utils.tensorboard import SummaryWriter

        self.writer = SummaryWriter(log_dir=save_dir)

    def report_scalar(self, name, value, iteration, group_name=None):
        """Write ``value`` at step ``iteration`` under ``<group_name>/<name>``.

        With the default ``group_name=None`` the tag is literally
        ``'None/<name>'``, because ``None`` is formatted into the string.
        """
        tag = f'{group_name}/{name}'
        self.writer.add_scalar(tag, value, iteration)

    def close(self):
        """Close the underlying writer."""
        self.writer.close()
class NoPlatform(TrainPlatform):
    """Reporter that records nothing; all hooks come from ``TrainPlatform``."""

    def __init__(self, save_dir):
        """Accept ``save_dir`` and ignore it."""