# Minh Q. Le
# Pushed COSMIC code
# a446b0b
import os
import comet.src.data.atomic as atomic_data
import comet.src.data.conceptnet as conceptnet_data
import comet.src.data.config as cfg
import comet.utils.utils as utils
import pickle
import torch
import json
start_token = "<START>"
end_token = "<END>"
blank_token = "<blank>"
def save_checkpoint(state, filename):
    """Serialize *state* (a checkpoint dict) to *filename* via torch.save,
    logging the destination path first."""
    message = "Saving model to {}".format(filename)
    print(message)
    torch.save(state, filename)
def save_step(model, vocab, optimizer, opt, length, lrs):
    """Checkpoint the model and optimizer state together with training
    metadata (epoch count, vocab, options, per-epoch learning rates).

    The target path is derived from *opt* via utils.make_name; when
    cfg.test_save is set, checkpoints go under garbage/models/ instead
    of models/.
    """
    path_prefix = "garbage/models/" if cfg.test_save else "models/"
    name = "{}.pickle".format(utils.make_name(
        opt, prefix=path_prefix, is_dir=False, eval_=True))
    checkpoint = {
        "epoch": length,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "opt": opt,
        "vocab": vocab,
        "epoch_learning_rates": lrs,
    }
    save_checkpoint(checkpoint, name)
def save_eval_file(opt, stats, eval_type="losses", split="dev", ext="pickle"):
    """Persist evaluation *stats* for a split to disk.

    The path is derived from *opt* via utils.make_name; cfg.test_save
    redirects output under garbage/ instead of results/. The extension
    selects the serialization format: "pickle" (binary pickle), "txt"
    (stats written verbatim as text), or "json".

    Raises:
        ValueError: if *ext* is not one of the supported extensions.
            (The original used a bare ``raise`` with no active exception,
            which produced an uninformative RuntimeError.)
    """
    dir_prefix = ("garbage/{}/" if cfg.test_save else "results/{}/").format(
        eval_type)
    name = "{}/{}.{}".format(utils.make_name(
        opt, prefix=dir_prefix, is_dir=True, eval_=True), split, ext)
    print("Saving {} {} to {}".format(split, eval_type, name))
    if ext == "pickle":
        with open(name, "wb") as f:
            pickle.dump(stats, f)
    elif ext == "txt":
        with open(name, "w") as f:
            f.write(stats)
    elif ext == "json":
        with open(name, "w") as f:
            json.dump(stats, f)
    else:
        raise ValueError("Unsupported extension: {}".format(ext))
def load_checkpoint(filename, gpu=True):
    """Load a checkpoint previously written by save_checkpoint.

    Tensors are always mapped to CPU storage via ``map_location``
    (NOTE: the *gpu* flag is currently unused; kept for interface
    compatibility with existing callers).

    Returns:
        The deserialized checkpoint dict, or None if *filename* does not
        exist (the original fell through and raised UnboundLocalError on
        the unassigned ``checkpoint`` variable in that case).
    """
    if not os.path.exists(filename):
        print("No model found at {}".format(filename))
        return None
    return torch.load(
        filename, map_location=lambda storage, loc: storage)
def make_data_loader(opt, *args):
    """Instantiate the GenerationDataLoader matching ``opt.dataset``.

    Supported datasets are "atomic" and "conceptnet"; extra positional
    arguments are forwarded to the loader constructor.

    Raises:
        ValueError: for an unrecognized dataset name (the original
            silently returned None, deferring the failure to the first
            attribute access on the loader).
    """
    if opt.dataset == "atomic":
        return atomic_data.GenerationDataLoader(opt, *args)
    if opt.dataset == "conceptnet":
        return conceptnet_data.GenerationDataLoader(opt, *args)
    raise ValueError("Unknown dataset: {}".format(opt.dataset))
def set_max_sizes(data_loader, force_split=None):
    """Populate ``data_loader.total_size`` with the example count of each
    split (size of dim 0 of the split's "total" tensor).

    When *force_split* is given, only that split is recorded.
    """
    if force_split is not None:
        data_loader.total_size = {
            force_split: data_loader.sequences[force_split]["total"].size(0)}
        return
    data_loader.total_size = {
        split: seqs["total"].size(0)
        for split, seqs in data_loader.sequences.items()}