Spaces:
Running
Running
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT license. | |
import os | |
import sys | |
import json | |
import tempfile | |
import socket | |
import string | |
import random | |
import ruamel.yaml as yaml | |
import psutil | |
from colorama import Fore | |
from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO | |
def get_yml_content(file_path):
    '''Load yaml file content.

    Args:
        file_path: path of the YAML file to read.

    Returns:
        The object produced by ``yaml.load`` for the file's contents.

    On a YAML scanner error, or any other exception raised while opening or
    parsing the file, an error message is printed and the process exits with
    status 1.
    '''
    try:
        with open(file_path, 'r') as file:
            # NOTE(review): yaml.Loader is presumably not the safe loader;
            # confirm this is only ever used on trusted, user-authored config.
            return yaml.load(file, Loader=yaml.Loader)
    except yaml.scanner.ScannerError as err:
        print_error('yaml file format error!')
        print_error(err)
        # sys.exit rather than the site-module ``exit`` builtin, which is not
        # guaranteed to exist (e.g. under ``python -S`` or in frozen builds).
        sys.exit(1)
    except Exception as exception:
        print_error(exception)
        sys.exit(1)
def get_json_content(file_path):
    '''Load json file content.

    Args:
        file_path: path of the JSON file to read.

    Returns:
        The decoded JSON object, or None if the content could not be decoded
        (an error message is printed in that case).

    Other errors from ``open`` (e.g. a missing file) propagate to the caller.
    '''
    try:
        with open(file_path, 'r') as file:
            return json.load(file)
    # Malformed JSON raises json.JSONDecodeError, a ValueError subclass; the
    # original handler caught only TypeError, so format errors escaped.
    except (TypeError, ValueError) as err:
        print_error('json file format error!')
        print_error(err)
        return None
def print_error(*content):
    '''Show an error message in red.

    Each value is converted with str(), the results are joined by single
    spaces, and the line is prefixed with ERROR_INFO.
    '''
    message = ' '.join(map(str, content))
    print(''.join((Fore.RED, ERROR_INFO, message, Fore.RESET)))
def print_green(*content):
    '''Show a message in green.

    Each value is converted with str() and the results are joined by single
    spaces; no prefix is added.
    '''
    text = ' '.join(map(str, content))
    print(''.join((Fore.GREEN, text, Fore.RESET)))
def print_normal(*content):
    '''Print normal (informational) messages to screen.

    The values are passed straight to print() after the NORMAL_INFO prefix,
    so print() separates them with single spaces and no colour codes are
    added.
    '''
    print(NORMAL_INFO, *content)
def print_warning(*content):
    '''Show a warning message in yellow.

    Each value is converted with str(), the results are joined by single
    spaces, and the line is prefixed with WARNING_INFO.
    '''
    body = ' '.join(map(str, content))
    print(''.join((Fore.YELLOW, WARNING_INFO, body, Fore.RESET)))
def detect_process(pid):
    '''Detect if a process is alive.

    Args:
        pid: id of the process to check.

    Returns:
        True if psutil reports the process as running, False if it does not
        exist, cannot be inspected, or ``pid`` is not usable.
    '''
    try:
        return psutil.Process(pid).is_running()
    except Exception:
        # Best-effort probe: any psutil error (e.g. no such process, access
        # denied) or an invalid pid means "not detectably alive". Unlike the
        # former bare ``except:``, this no longer swallows KeyboardInterrupt
        # or SystemExit.
        return False
def detect_port(port):
    '''Detect if the port is used on localhost.

    Attempts a TCP connection to 127.0.0.1:port.

    Args:
        port: port number, or anything ``int()`` accepts (e.g. a numeric str).

    Returns:
        True if the connection succeeded (something is listening), False if
        it failed or ``port`` is not a valid port number.
    '''
    try:
        port_number = int(port)
    except (TypeError, ValueError):
        return False
    # The context manager closes the socket on every path; previously it was
    # only closed after a successful connect and leaked on failure.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_test:
        try:
            socket_test.connect(('127.0.0.1', port_number))
        except (OSError, OverflowError):
            # OSError: refused/unreachable; OverflowError: port outside 0-65535.
            return False
    return True
def get_user():
    '''Return the name of the current OS user.

    Reads USERNAME on Windows and USER elsewhere, as before. If that variable
    is unset or empty (common in containers and cron jobs), falls back to
    ``getpass.getuser()`` instead of raising KeyError.
    '''
    env_var = 'USERNAME' if sys.platform == 'win32' else 'USER'
    user = os.environ.get(env_var)
    if user:
        return user
    import getpass
    return getpass.getuser()
def check_tensorboard_version():
    '''Return the installed tensorboard version string.

    Prints an error and exits with status 1 if tensorboard cannot be imported
    or its ``__version__`` cannot be read.
    '''
    try:
        import tensorboard
        return tensorboard.__version__
    except Exception:
        # ``except Exception`` instead of a bare ``except:`` so that Ctrl-C
        # (KeyboardInterrupt) is not misreported as an import failure.
        print_error('import tensorboard error!')
        # sys.exit rather than the site-module ``exit`` builtin.
        sys.exit(1)
def generate_temp_dir():
    '''Generate a fresh temp folder and return its path.

    The folder is created as ``<system temp dir>/nni/<8 random letters or
    digits>``; the parent ``nni`` directory is created if needed.
    '''
    base_dir = os.path.join(tempfile.gettempdir(), 'nni')
    os.makedirs(base_dir, exist_ok=True)
    alphabet = string.ascii_letters + string.digits
    while True:
        candidate = os.path.join(base_dir, ''.join(random.sample(alphabet, 8)))
        # os.mkdir is atomic: it either creates the folder or fails because
        # it exists. This replaces the racy exists()-then-makedirs() pattern,
        # which could crash with FileExistsError under concurrent callers.
        try:
            os.mkdir(candidate)
        except FileExistsError:
            continue
        return candidate