# wordle-solver/a3c/worker.py
"""
Worker class implementation of the A3C (Asynchronous Advantage Actor-Critic) algorithm for discrete action spaces
"""
import os
import torch
import numpy as np
import torch.multiprocessing as mp
from .net import Net
from .utils import v_wrap
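

# Typical wiring of these workers in an A3C training loop (a minimal sketch,
# not this repo's exact entry point; `SharedAdam` and `make_env` below are
# assumed names for the shared optimizer and the environment factory):
#
#   gnet = Net(N_S, N_A, words_list, word_width)
#   gnet.share_memory()                      # share parameters across processes
#   opt = SharedAdam(gnet.parameters())      # assumed shared optimizer
#   global_ep = mp.Value('i', 0)
#   global_ep_r = mp.Value('d', 0.)
#   winning_ep = mp.Value('i', 0)
#   res_queue = mp.Queue()
#   workers = [
#       Worker(MAX_EP, gnet, opt, global_ep, global_ep_r, res_queue, i,
#              make_env(), N_S, N_A, words_list, word_width, winning_ep,
#              "checkpoints", save=True)
#       for i in range(mp.cpu_count())
#   ]
#   [w.start() for w in workers]
#   [w.join() for w in workers]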
class Worker(mp.Process):
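    """
    A3C worker process.

    Each worker keeps a local copy of the actor-critic network, plays episodes
    against its own environment, and at the end of every episode pushes its
    gradients to the shared global network before pulling back the updated
    weights.
    """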
def __init__(
self,
max_ep,
gnet,
opt,
global_ep,
global_ep_r,
res_queue,
name,
env,
N_S,
N_A,
words_list,
word_width,
winning_ep,
model_checkpoint_dir,
gamma=0.,
pretrained_model_path=None,
save=False,
min_reward=9.9,
every_n_save=100
):
super(Worker, self).__init__()
self.max_ep = max_ep
self.name = 'w%02i' % name
self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
self.gnet, self.opt = gnet, opt
self.word_list = words_list
# local network
self.lnet = Net(N_S, N_A, words_list, word_width)
if pretrained_model_path:
self.lnet.load_state_dict(torch.load(pretrained_model_path))
self.env = env.unwrapped
self.gamma = gamma
self.model_checkpoint_dir = model_checkpoint_dir
self.save = save
self.min_reward = min_reward
self.every_n_save = every_n_save

    def run(self):
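        """
        Episode loop: pick actions with the local net, buffer the transitions,
        and at the end of each episode sync with the global net, record the
        result, and optionally save a checkpoint. Stops once the shared
        episode counter reaches max_ep.
        """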
while self.g_ep.value < self.max_ep:
s = self.env.reset()
buffer_s, buffer_a, buffer_r = [], [], []
ep_r = 0.
while True:
a = self.lnet.choose_action(v_wrap(s[None, :]))
s_, r, done, _ = self.env.step(a)
ep_r += r
buffer_a.append(a)
buffer_s.append(s)
buffer_r.append(r)
if done: # update global and assign to local net
# sync
self.push_and_pull(done, s_, buffer_s,
buffer_a, buffer_r)
goal_word = self.word_list[self.env.goal_word]
self.record(ep_r, goal_word,
self.word_list[a], len(buffer_a))
self.save_model()
buffer_s, buffer_a, buffer_r = [], [], []
break
s = s_
self.res_queue.put(None)

    def push_and_pull(self, done, s_, bs, ba, br):
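        """
        Turn the buffered rewards into discounted value targets, compute the
        local loss, copy the local gradients onto the global network, step the
        shared optimizer, and reload the updated global weights into the local
        net.
        """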
if done:
v_s_ = 0. # terminal
else:
            v_s_ = self.lnet(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]
buffer_v_target = []
for r in br[::-1]: # reverse buffer r
v_s_ = r + self.gamma * v_s_
buffer_v_target.append(v_s_)
buffer_v_target.reverse()
loss = self.lnet.loss_func(
v_wrap(np.vstack(bs)),
v_wrap(np.array(ba), dtype=np.int64) if ba[0].dtype == np.int64 else v_wrap(np.vstack(ba)),
v_wrap(np.array(buffer_v_target)[:, None]))
# calculate local gradients and push local parameters to global
self.opt.zero_grad()
loss.backward()
for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
gp._grad = lp.grad
self.opt.step()
# pull global parameters
self.lnet.load_state_dict(self.gnet.state_dict())

    def save_model(self):
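        """
        Checkpoint the global network when saving is enabled, the running
        reward has reached min_reward, and the global episode count is a
        multiple of every_n_save.
        """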
if self.save and self.g_ep_r.value >= self.min_reward and self.g_ep.value % self.every_n_save == 0:
torch.save(self.gnet.state_dict(), os.path.join(
self.model_checkpoint_dir, f'model_{self.g_ep.value}.pth'))

    def record(self, ep_r, goal_word, action, action_number):
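        """
        Update the shared episode and win counters and the exponentially
        smoothed running reward, push the running reward to the results queue,
        and print progress every 100 episodes.
        """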
with self.g_ep.get_lock():
self.g_ep.value += 1
with self.g_ep_r.get_lock():
if self.g_ep_r.value == 0.:
self.g_ep_r.value = ep_r
else:
self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
self.res_queue.put(self.g_ep_r.value)
if goal_word == action:
self.winning_ep.value += 1
if self.g_ep.value % 100 == 0:
print(
self.name,
"Ep:", self.g_ep.value,
"| Ep_r: %.0f" % self.g_ep_r.value,
"| Goal :", goal_word,
"| Action: ", action,
"| Actions: ", action_number
)