qinfeng722's picture
Upload 322 files
5caedb4 verified
raw
history blame
5.96 kB
import logging
import os
import pickle
import random
import zipfile
from typing import Any, Optional
import numpy as np
import psutil
import torch
logger = logging.getLogger(__name__)
def set_seed(seed: int = 1234) -> None:
"""
Sets the random seed for various Python libraries to ensure reproducibility
of results across different runs.
Args:
seed (int, optional): seed value. Defaults to 1234.
"""
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
def check_metric(cfg):
"""
Checks if the metric is set to GPT and if the OpenAI API key is set.
If not, sets the metric to BLEU and logs a warning.
"""
if "GPT" in cfg.prediction.metric and os.getenv("OPENAI_API_KEY", "") == "":
logger.warning("No OpenAI API Key set. Setting metric to BLEU. ")
cfg.prediction.metric = "BLEU"
return cfg
def kill_child_processes(current_pid: int, exclude=None) -> bool:
"""
Killing all child processes of the current process.
Optionally, excludes one PID
Args:
current_pid: current process id
exclude: process id to exclude
Returns:
True or False in case of success or failure
"""
logger.debug(f"Killing process id: {current_pid}")
try:
current_process = psutil.Process(current_pid)
if current_process.status() == "zombie":
return False
children = current_process.children(recursive=True)
for child in children:
if child.pid == exclude:
continue
child.kill()
return True
except psutil.NoSuchProcess:
logger.warning(f"Cannot kill process id: {current_pid}. No such process.")
return False
def kill_child_processes_and_current(current_pid: Optional[int] = None) -> bool:
"""
Kill all child processes of the current process, then terminates itself.
Optionally, specify the current process id.
If not specified, uses the current process id.
Args:
current_pid: current process id (default: None)
Returns:
True or False in case of success or failure
"""
if current_pid is None:
current_pid = os.getpid()
kill_child_processes(current_pid)
try:
current_process = psutil.Process(current_pid)
current_process.kill()
return True
except psutil.NoSuchProcess:
logger.warning(f"Cannot kill process id: {current_pid}. No such process.")
return False
def kill_sibling_ddp_processes() -> None:
"""
Killing all sibling DDP processes from a single DDP process.
"""
pid = os.getpid()
parent_pid = os.getppid()
kill_child_processes(parent_pid, exclude=pid)
current_process = psutil.Process(pid)
current_process.kill()
def add_file_to_zip(zf: zipfile.ZipFile, path: str, folder=None) -> None:
"""Adds a file to the existing zip. Does nothing if file does not exist.
Args:
zf: zipfile object to add to
path: path to the file to add
folder: folder in the zip to add the file to
"""
try:
if folder is None:
zip_path = os.path.basename(path)
else:
zip_path = os.path.join(folder, os.path.basename(path))
zf.write(path, zip_path)
except Exception:
logger.warning(f"File {path} could not be added to zip.")
def save_pickle(path: str, obj: Any, protocol: int = 4) -> None:
"""Saves object as pickle file
Args:
path: path of file to save
obj: object to save
protocol: protocol to use when saving pickle
"""
with open(path, "wb") as pickle_file:
pickle.dump(obj, pickle_file, protocol=protocol)
class DisableLogger:
def __init__(self, level: int = logging.INFO):
self.level = level
def __enter__(self):
logging.disable(self.level)
def __exit__(self, exit_type, exit_value, exit_traceback):
logging.disable(logging.NOTSET)
class PatchedAttribute:
"""
Patches an attribute of an object for the duration of this context manager.
Similar to unittest.mock.patch,
but works also for properties that are not present in the original class
>>> class MyObj:
... attr = 'original'
>>> my_obj = MyObj()
>>> with PatchedAttribute(my_obj, 'attr', 'patched'):
... print(my_obj.attr)
patched
>>> print(my_obj.attr)
original
>>> with PatchedAttribute(my_obj, 'new_attr', 'new_patched'):
... print(my_obj.new_attr)
new_patched
>>> assert not hasattr(my_obj, 'new_attr')
"""
def __init__(self, obj, attribute, new_value):
self.obj = obj
self.attribute = attribute
self.new_value = new_value
self.original_exists = hasattr(obj, attribute)
if self.original_exists:
self.original_value = getattr(obj, attribute)
def __enter__(self):
setattr(self.obj, self.attribute, self.new_value)
def __exit__(self, exc_type, exc_val, exc_tb):
if self.original_exists:
setattr(self.obj, self.attribute, self.original_value)
else:
delattr(self.obj, self.attribute)
def create_symlinks_in_parent_folder(directory):
"""Creates symlinks for each item in a folder in the parent folder
Only creates symlinks for items at the root level of the directory.
"""
if not os.path.exists(directory):
raise FileNotFoundError(f"Directory {directory} does not exist.")
parent_directory = os.path.dirname(directory)
for file in os.listdir(directory):
src = os.path.join(directory, file)
dst = os.path.join(parent_directory, file)
if os.path.exists(dst):
os.remove(dst)
os.symlink(src, dst)