# NOTE(review): the three lines below ("Spaces:", "Sleeping", "Sleeping")
# appear to be paste/extraction residue, not Python source; commented out
# so the module can be imported.
# Spaces:
# Sleeping
# Sleeping
# Standard library imports, then third-party (numpy / torch).
import datetime
import json
import math
import os
import random
import time
from json import encoder

import numpy as np
import torch
# Per-output-format configuration: decoder field name, tokenizer vocab file,
# and the maximum token sequence length used for that format.
FORMAT_INFO = {
    "inchi": {
        "name": "InChI_text",
        "tokenizer": "tokenizer_inchi.json",
        "max_len": 300,
    },
    "atomtok": {
        "name": "SMILES_atomtok",
        "tokenizer": "tokenizer_smiles_atomtok.json",
        "max_len": 256,
    },
    # Coordinate-aware formats only carry a length cap.
    "nodes": {"max_len": 384},
    "atomtok_coords": {"max_len": 480},
    "chartok_coords": {"max_len": 480},
}
def init_logger(log_file='train.log'):
    """Create a logger that writes messages to both stdout and ``log_file``.

    Fix: the original appended new handlers on every call, so calling
    ``init_logger`` more than once (e.g. across CV folds) duplicated every
    log line. Existing handlers are now removed first, making the function
    idempotent.
    """
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # Drop handlers left over from a previous init_logger() call.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
    stream_handler = StreamHandler()
    stream_handler.setFormatter(Formatter("%(message)s"))
    file_handler = FileHandler(filename=log_file)
    file_handler.setFormatter(Formatter("%(message)s"))
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    return logger
def init_summary_writer(save_path):
    """Return a tensorboardX ``SummaryWriter`` that logs under ``save_path``."""
    from tensorboardX import SummaryWriter
    return SummaryWriter(save_path)
def save_args(args):
    """Dump every attribute of ``args`` to a timestamped log file.

    Writes ``train_<yymmdd-HHMM>.log`` under ``args.save_path``, one
    ``**** key = *value*`` line per attribute.
    """
    stamp = datetime.datetime.now().strftime("%y%m%d-%H%M")
    out_path = os.path.join(args.save_path, f'train_{stamp}.log')
    lines = [f"**** {k} = *{v}*\n" for k, v in vars(args).items()]
    with open(out_path, 'w') as f:
        f.writelines(lines)
def seed_torch(seed=42):
    """Seed all RNGs (Python, hash, NumPy, torch CPU/CUDA) for reproducibility.

    Fixes: the original seeded only the *current* CUDA device and left
    cuDNN's benchmark autotuner enabled, which can select nondeterministic
    kernels; we now seed every device and disable benchmarking.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # multi-GPU: seed every device
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # autotuning breaks determinism
class AverageMeter(object):
    """Tracks the most recent value and a running (weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0    # most recent value passed to update()
        self.sum = 0    # weighted sum of all values
        self.count = 0  # total weight seen so far
        self.avg = 0    # running mean = sum / count

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count
class EpochMeter(AverageMeter):
    """AverageMeter that also accumulates an epoch-level average.

    The inherited ``reset()`` clears only the batch-level statistics; the
    nested ``epoch`` meter keeps accumulating across resets.
    """

    def __init__(self):
        super().__init__()
        self.epoch = AverageMeter()

    def update(self, val, n=1):
        self.epoch.update(val, n)
        super().update(val, n)
class LossMeter(EpochMeter):
    """EpochMeter for a total loss plus named sub-losses.

    Each component in ``losses`` gets its own lazily-created ``EpochMeter``
    stored in ``self.subs``; ``reset()`` clears the batch statistics of the
    total and of every sub-meter.
    """

    def __init__(self):
        # ``subs`` must exist before super().__init__() triggers reset().
        self.subs = {}
        super().__init__()

    def reset(self):
        super().reset()
        for meter in self.subs.values():
            meter.reset()

    def update(self, loss, losses, n=1):
        # ``loss`` and the values of ``losses`` are expected to be scalar
        # tensors (anything exposing ``.item()``).
        super().update(loss.item(), n)
        for name, value in losses.items():
            if name not in self.subs:
                self.subs[name] = EpochMeter()
            self.subs[name].update(value.item(), n)
def asMinutes(s):
    """Format a duration of ``s`` seconds as a ``'<m>m <s>s'`` string."""
    minutes, seconds = divmod(s, 60)
    return '%dm %ds' % (minutes, seconds)
def timeSince(since, percent):
    """Return '<elapsed> (remain <estimate>)' for a running job.

    ``since`` is a ``time.time()`` start stamp and ``percent`` the fraction
    of work already completed; remaining = elapsed / percent - elapsed.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / percent
    remaining = estimated_total - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))
def print_rank_0(message):
    """Print ``message`` once per job: under ``torch.distributed`` only
    rank 0 prints; outside a distributed run, always print."""
    not_distributed = not torch.distributed.is_initialized()
    if not_distributed or torch.distributed.get_rank() == 0:
        print(message, flush=True)
def to_device(data, device):
    """Recursively move tensors nested in lists/tuples/dicts to ``device``.

    Fix: the original had no fallback branch, so any non-tensor leaf
    (string labels, ints, ``None``, tuples) silently became ``None``;
    such leaves are now returned unchanged, and tuples are supported.
    """
    if torch.is_tensor(data):
        return data.to(device)
    if isinstance(data, list):
        return [to_device(v, device) for v in data]
    if isinstance(data, tuple):
        return tuple(to_device(v, device) for v in data)
    if isinstance(data, dict):
        return {k: to_device(v, device) for k, v in data.items()}
    return data  # non-tensor leaf: leave untouched
def round_floats(o):
    """Recursively round every float in ``o`` to 3 decimal places.

    Dicts keep their keys; lists and tuples both come back as plain lists
    (convenient for JSON serialization downstream). Anything else is
    returned unchanged.
    """
    if isinstance(o, dict):
        return {k: round_floats(v) for k, v in o.items()}
    if isinstance(o, (list, tuple)):
        return [round_floats(x) for x in o]
    return round(o, 3) if isinstance(o, float) else o
def format_df(df):
    """JSON-encode the structured prediction columns of ``df`` in place.

    For each of the ``node_coords`` / ``node_symbols`` / ``edges`` columns
    present, every cell is replaced by a compact (whitespace-free) JSON
    string with floats rounded via ``round_floats``; ``None`` cells stay
    ``None``. Returns the mutated DataFrame.
    """
    def _encode(cell):
        if cell is None:
            return None
        return json.dumps(round_floats(cell)).replace(" ", "")

    for column in ('node_coords', 'node_symbols', 'edges'):
        if column in df.columns:
            df[column] = [_encode(cell) for cell in df[column]]
    return df