import os
import json
import pickle

import torch

import comet.src.data.atomic as atomic_data
import comet.src.data.conceptnet as conceptnet_data
import comet.src.data.config as cfg
import comet.utils.utils as utils

# Special vocabulary tokens marking sequence boundaries and masked blanks.
start_token = "<START>"
end_token = "<END>"
blank_token = "<blank>"

def save_checkpoint(state, filename):
    # Serialize a checkpoint dict (weights, optimizer state, etc.) to disk.
    print("Saving model to {}".format(filename))
    torch.save(state, filename)

def save_step(model, vocab, optimizer, opt, length, lrs):
    # Test runs write checkpoints under garbage/models/ so they never
    # overwrite real checkpoints under models/.
    if cfg.test_save:
        name = "{}.pickle".format(utils.make_name(
            opt, prefix="garbage/models/", is_dir=False, eval_=True))
    else:
        name = "{}.pickle".format(utils.make_name(
            opt, prefix="models/", is_dir=False, eval_=True))
    save_checkpoint({
        "epoch": length, "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(), "opt": opt,
        "vocab": vocab, "epoch_learning_rates": lrs},
        name)
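

# Hedged usage sketch (not part of the original module): how a training
# loop might checkpoint at the end of an epoch. `model`, `vocab`,
# `optimizer`, `opt`, `epoch`, and `lrs` are hypothetical caller-side
# names; `lrs` is assumed to be the list of learning rates used so far.
def _example_save_step(model, vocab, optimizer, opt, epoch, lrs):
    save_step(model, vocab, optimizer, opt, length=epoch, lrs=lrs)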

def save_eval_file(opt, stats, eval_type="losses", split="dev", ext="pickle"):
    # Mirror the checkpoint convention: test runs go to garbage/,
    # real runs to results/.
    if cfg.test_save:
        name = "{}/{}.{}".format(utils.make_name(
            opt, prefix="garbage/{}/".format(eval_type),
            is_dir=True, eval_=True), split, ext)
    else:
        name = "{}/{}.{}".format(utils.make_name(
            opt, prefix="results/{}/".format(eval_type),
            is_dir=True, eval_=True), split, ext)
    print("Saving {} {} to {}".format(split, eval_type, name))

    if ext == "pickle":
        with open(name, "wb") as f:
            pickle.dump(stats, f)
    elif ext == "txt":
        with open(name, "w") as f:
            f.write(stats)
    elif ext == "json":
        with open(name, "w") as f:
            json.dump(stats, f)
    else:
        # A bare `raise` outside an except block is itself an error;
        # raise an explicit exception for unsupported extensions instead.
        raise ValueError("Unsupported extension: {}".format(ext))
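

# Hedged usage sketch: persisting evaluation statistics as JSON. The
# `opt` object and the stats values here are illustrative only; the
# output path is derived from opt via utils.make_name.
def _example_save_eval(opt):
    stats = {"loss": 2.51, "perplexity": 12.3}  # illustrative values only
    save_eval_file(opt, stats, eval_type="losses", split="dev", ext="json")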

def load_checkpoint(filename, gpu=True):
    # Tensors are mapped to CPU on load regardless of `gpu`; callers move
    # the model to the target device afterwards. Initialize `checkpoint`
    # so a missing file returns None instead of raising UnboundLocalError.
    checkpoint = None
    if os.path.exists(filename):
        checkpoint = torch.load(
            filename, map_location=lambda storage, loc: storage)
    else:
        print("No model found at {}".format(filename))
    return checkpoint
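

# Hedged usage sketch: restoring model and optimizer state from a file
# written by save_step. `model`, `optimizer`, and `model_file` are
# hypothetical caller-side names.
def _example_resume(model, optimizer, model_file):
    checkpoint = load_checkpoint(model_file)
    if checkpoint is not None:
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint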

def make_data_loader(opt, *args):
    # Dispatch on the dataset name; extra positional args are forwarded
    # to the chosen loader's constructor.
    if opt.dataset == "atomic":
        return atomic_data.GenerationDataLoader(opt, *args)
    elif opt.dataset == "conceptnet":
        return conceptnet_data.GenerationDataLoader(opt, *args)
    # Implicitly returns None for unrecognized dataset names.
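

# Hedged usage sketch: `opt` is assumed to expose a `dataset` attribute
# set to "atomic" or "conceptnet". Guarding against the implicit None
# return turns a silent misconfiguration into a clear error.
def _example_make_loader(opt, *loader_args):
    data_loader = make_data_loader(opt, *loader_args)
    if data_loader is None:
        raise ValueError("Unknown dataset: {}".format(opt.dataset))
    return data_loader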

def set_max_sizes(data_loader, force_split=None):
    # Record the number of examples in each split (or only in
    # `force_split`, when given) so batching code knows the split sizes.
    data_loader.total_size = {}
    if force_split is not None:
        data_loader.total_size[force_split] = \
            data_loader.sequences[force_split]["total"].size(0)
        return
    for split in data_loader.sequences:
        data_loader.total_size[split] = \
            data_loader.sequences[split]["total"].size(0)
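

# Hedged usage sketch: once a loader's sequences have been built,
# set_max_sizes populates data_loader.total_size with per-split example
# counts; the split names depend on how data_loader.sequences was keyed.
def _example_split_sizes(data_loader):
    set_max_sizes(data_loader)
    for split, size in data_loader.total_size.items():
        print("{}: {} examples".format(split, size))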