wordle-solver / a3c / train.py
import os

import torch
import torch.multiprocessing as mp

from .net import Net
from .shared_adam import SharedAdam
from .worker import Worker


def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None,
          save=False, min_reward=9.9, every_n_save=100):
    """Train the A3C agent: spawn one worker per CPU core, collect episode
    rewards from the shared result queue, and optionally save the global
    network's weights to model_checkpoint_dir when training finishes."""
    os.environ["OMP_NUM_THREADS"] = "1"
    if not os.path.exists(model_checkpoint_dir):
        os.makedirs(model_checkpoint_dir)
    n_s = env.observation_space.shape[0]  # observation (state) size
    n_a = env.action_space.n              # number of actions (guessable words)
    words_list = env.words
    word_width = len(env.words[0])
    gnet = Net(n_s, n_a, words_list, word_width)  # global network
    if pretrained_model_path:
        gnet.load_state_dict(torch.load(pretrained_model_path))
    gnet.share_memory()  # share the global parameters in multiprocessing
    opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))  # global optimizer
    global_ep, global_ep_r, res_queue, win_ep = (
        mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0))
    # parallel training: one worker process per CPU core
    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
                      words_list, word_width, win_ep, model_checkpoint_dir, gamma,
                      pretrained_model_path, save, min_reward, every_n_save)
               for i in range(mp.cpu_count())]
    [w.start() for w in workers]
    res = []  # record episode rewards to plot
    while True:
        r = res_queue.get()
        if r is not None:
            res.append(r)
        else:  # a None in the queue signals that all workers have finished
            break
    [w.join() for w in workers]
    if save:
        torch.save(gnet.state_dict(),
                   os.path.join(model_checkpoint_dir, f'model_{env.unwrapped.spec.id}.pth'))
    return global_ep, win_ep, gnet, res
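

# Hypothetical usage sketch (not part of train.py): how this entry point
# might be called from a training script. The environment id "WordleEnv-v0"
# and the use of gym.make() are assumptions, not taken from this repository;
# any Gym-style Wordle environment exposing `words`, a discrete action space,
# a flat observation space, and a registered spec id would fit train()'s
# expectations. The module path `a3c.train` follows the file path above.
#
#     import gym
#
#     from a3c.train import train
#
#     if __name__ == "__main__":
#         env = gym.make("WordleEnv-v0")           # assumed environment id
#         global_ep, win_ep, gnet, res = train(
#             env,
#             max_ep=5000,                         # episodes shared across workers
#             model_checkpoint_dir="checkpoints",  # where checkpoints are written
#             save=True,                           # save gnet when training ends
#             min_reward=9.9,                      # reward threshold used by workers
#             every_n_save=100,                    # checkpoint frequency (episodes)
#         )
#         print(f"episodes: {global_ep.value}, wins: {win_ep.value}")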