Spaces:
Running
Running
import os | |
import glob | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import torch | |
def parse_filelist(filelist_path, split_char="|"):
    """Read a delimited filelist and return its rows as lists of fields.

    Each line of the file is stripped of surrounding whitespace and then
    split on ``split_char``. Every line yields exactly one row, so a blank
    line becomes ``['']``.

    Args:
        filelist_path: Path to a UTF-8 encoded text file.
        split_char: Field separator; defaults to ``"|"``.

    Returns:
        A list containing one list of string fields per line.
    """
    rows = []
    with open(filelist_path, encoding='utf-8') as handle:
        for raw_line in handle:
            rows.append(raw_line.strip().split(split_char))
    return rows
def load_model(model, saved_state_dict):
    """Copy matching entries from ``saved_state_dict`` into ``model``.

    Every key of ``model.state_dict()`` is looked up in ``saved_state_dict``.
    Keys found there take the saved value. Keys missing from it keep the
    model's current value and are reported on stdout. Keys that exist only
    in ``saved_state_dict`` are ignored.

    Args:
        model: A ``torch.nn.Module`` updated in place.
        saved_state_dict: Mapping from parameter/buffer name to tensor
            (presumably the result of ``torch.load`` on a state-dict
            checkpoint — confirm against callers).

    Returns:
        The same ``model`` object after ``load_state_dict``.
    """
    state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            # Only a missing key means "keep the model's own value"; any other
            # error (wrong mapping type, Ctrl-C, ...) must propagate instead of
            # being silently turned into a fallback as the bare except did.
            print(f"{k} is not in the checkpoint")
            new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
    return model
def latest_checkpoint_path(dir_path, regex="grad_svc_*.pt"):
    """Return the path of the highest-ranked checkpoint in ``dir_path``.

    Despite its name, ``regex`` is a shell-style glob pattern handed to
    ``glob.glob``. Each match is ranked by the integer formed from *all*
    digit characters in its full path (directory included), and the match
    with the largest such integer is returned.

    Raises:
        IndexError: if no file matches the pattern.
        ValueError: if a matching path contains no digit characters.
    """
    def _digit_key(path):
        # Concatenate every digit in the path into one integer,
        # e.g. "ckpt/grad_svc_120.pt" -> 120.
        return int("".join(ch for ch in path if ch.isdigit()))

    candidates = sorted(glob.glob(os.path.join(dir_path, regex)), key=_digit_key)
    return candidates[-1]
def load_checkpoint(logdir, model, num=None):
    """Load a ``grad_svc_<num>.pt`` state dict from ``logdir`` into ``model``.

    Args:
        logdir: Directory holding the checkpoint files.
        model: Module whose parameters are overwritten in place.
        num: Checkpoint number to load. When ``None``, the file chosen by
            ``latest_checkpoint_path(logdir, regex="grad_svc_*.pt")`` is used.

    Returns:
        ``model`` after a non-strict ``load_state_dict`` call, so missing or
        unexpected keys are tolerated rather than raising.
    """
    model_path = (
        latest_checkpoint_path(logdir, regex="grad_svc_*.pt")
        if num is None
        else os.path.join(logdir, f"grad_svc_{num}.pt")
    )
    print(f'Loading checkpoint {model_path}...')
    # torch calls map_location as (storage, location); returning the storage
    # unchanged loads every tensor onto the CPU whatever device it was saved on.
    checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
    model.load_state_dict(checkpoint, strict=False)
    return model
def save_figure_to_numpy(fig):
    """Return the rendered pixels of a Matplotlib figure as an RGB array.

    The figure's canvas must already have been drawn (``fig.canvas.draw()``)
    and must provide ``buffer_rgba()``, as Agg-based canvases do.

    Args:
        fig: A drawn Matplotlib ``Figure``.

    Returns:
        A new C-contiguous ``uint8`` array of shape ``(height, width, 3)``
        that does not share memory with the canvas.
    """
    # `canvas.tostring_rgb` (deprecated in Matplotlib 3.8, removed in 3.10)
    # and binary-mode `np.fromstring` (deprecated in NumPy) no longer work on
    # current releases. `buffer_rgba()` also carries the real pixel shape, so
    # there is no reshape against `get_width_height()`, whose logical size can
    # differ from the physical buffer on HiDPI canvases.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop the alpha channel and force a copy: the canvas buffer is reused or
    # freed once the figure is redrawn or closed.
    return np.array(rgba[..., :3], dtype=np.uint8, order="C")
def _tensor_figure(tensor):
    """Build a drawn 12x3-inch heat-map figure of ``tensor`` with a colorbar.

    ``tensor`` is anything ``Axes.imshow`` accepts (presumably a 2-D feature
    map such as a spectrogram — confirm against callers). Torch tensors are
    detached and moved to the CPU first, so tensors that require grad or live
    on an accelerator can be plotted directly. As a side effect, Matplotlib's
    global style is reset to 'default'. The caller owns the returned figure
    and must close it.
    """
    if torch.is_tensor(tensor):
        tensor = tensor.detach().cpu().numpy()
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.canvas.draw()
    except Exception:
        # Don't leave a half-built figure registered with pyplot.
        plt.close(fig)
        raise
    return fig


def plot_tensor(tensor):
    """Render ``tensor`` as a heat map and return the image as an RGB array.

    Returns:
        The ``uint8`` array produced by ``save_figure_to_numpy`` for the
        rendered figure.
    """
    fig = _tensor_figure(tensor)
    try:
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure (not just "the current one"), even on
        # error, so repeated calls during training don't accumulate figures.
        plt.close(fig)


def save_plot(tensor, savepath):
    """Render ``tensor`` as a heat map and write the image to ``savepath``."""
    fig = _tensor_figure(tensor)
    try:
        fig.savefig(savepath)
    finally:
        plt.close(fig)
def print_error(info):
    """Print ``info`` to stdout, padded with spaces and wrapped in ANSI red."""
    colored = "\033[31m {} \033[0m".format(info)
    print(colored)