|
import sys |
|
import os |
|
from toolkit.accelerator import get_accelerator |
|
|
|
|
|
def print_acc(*args, **kwargs):
    """Forward ``*args``/``**kwargs`` to ``print`` on the local main process.

    Printing happens only when ``get_accelerator().is_local_main_process``
    is true; on any other process the call does nothing.
    """
    if not get_accelerator().is_local_main_process:
        return
    print(*args, **kwargs)
|
|
|
|
|
class Logger:
    """Text stream that tees everything written to it into a log file.

    Each write goes to ``terminal`` and is appended to ``filename``; the log
    file is flushed after every write so its contents stay current.
    ``setup_log_to_file`` installs instances as ``sys.stdout``/``sys.stderr``.

    Attributes not defined here (``isatty``, ``fileno``, ``encoding``, ...)
    are forwarded to ``terminal`` so code that inspects the standard streams
    keeps working after they are replaced.
    """

    def __init__(self, filename, terminal=None):
        """Open ``filename`` for appending and remember the terminal stream.

        Args:
            filename: Log file path, opened in append mode as UTF-8.
            terminal: Stream that also receives every write. Defaults to
                whatever ``sys.stdout`` is at construction time.
        """
        self.terminal = sys.stdout if terminal is None else terminal
        # UTF-8 rather than the locale default, so writing non-ASCII text
        # cannot raise UnicodeEncodeError on platforms with a narrow locale.
        self.log = open(filename, 'a', encoding='utf-8')

    def write(self, message):
        """Write ``message`` to the terminal and append it to the log file.

        Returns:
            The number of characters written, like ``io.TextIOBase.write``.
        """
        self.terminal.write(message)
        self.log.write(message)
        # Flush every write so the file is current even if the process dies.
        self.log.flush()
        return len(message)

    def flush(self):
        """Flush both the terminal stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file. The terminal stream is left open."""
        self.log.close()

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Read 'terminal'
        # through __dict__ so an instance missing it (e.g. created via
        # __new__) raises AttributeError instead of recursing forever.
        terminal = self.__dict__.get('terminal')
        if terminal is None:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(terminal, name)
|
|
|
|
|
def setup_log_to_file(filename):
    """Mirror stdout and stderr into ``filename`` on the local main process.

    Only runs when ``get_accelerator().is_local_main_process`` is true; other
    processes keep their original streams. Output still reaches the terminal
    and is additionally appended to ``filename``.

    Args:
        filename: Path of the log file. Missing parent directories are
            created; a bare file name with no directory part is accepted.
    """
    if not get_accelerator().is_local_main_process:
        return

    log_dir = os.path.dirname(filename)
    # dirname() is '' for a bare file name and makedirs('') raises, so only
    # create a directory when there is one. exist_ok tolerates the directory
    # appearing between any check and the create call.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Build both loggers before replacing either stream: Logger captures the
    # current sys.stdout as its terminal, so a stderr logger built after
    # sys.stdout was swapped would chain through the stdout logger and write
    # every stderr message to the log file twice.
    stdout_logger = Logger(filename)
    stderr_logger = Logger(filename)
    # Keep stderr output on the real stderr rather than on stdout.
    stderr_logger.terminal = sys.stderr
    sys.stdout = stdout_logger
    sys.stderr = stderr_logger
|
|