"""Worker class implementation of the a3c discrete algorithm."""
import os

import numpy as np
import torch
import torch.multiprocessing as mp
from torch import nn

from .net import Net
from .utils import v_wrap


def _locked_increment(shared):
    """Add one to ``shared.value``, holding the object's lock when it has one.

    ``+=`` on a shared value is a read-modify-write and is not atomic across
    processes.  Objects without a ``get_lock`` method are incremented
    directly, which matches the previous behaviour for such objects.
    """
    get_lock = getattr(shared, "get_lock", None)
    if get_lock is None:
        shared.value += 1
        return
    with get_lock():
        shared.value += 1


class Worker(mp.Process):
    """Process that plays episodes with a local net and updates a global net.

    Each worker owns a local ``Net``.  After every finished episode it
    computes the loss on the local net, hands the resulting gradients to the
    global net's parameters, steps the shared optimizer and then reloads the
    global weights into the local net.

    Shared state (``global_ep``, ``global_ep_r``, ``winning_ep``) is accessed
    through ``.value``; ``global_ep`` and ``global_ep_r`` must also provide
    ``get_lock()``.  NOTE(review): presumably these are
    ``multiprocessing.Value`` objects created by the launcher -- confirm.
    """

    def __init__(
        self,
        max_ep,
        gnet,
        opt,
        global_ep,
        global_ep_r,
        res_queue,
        name,
        env,
        N_S,
        N_A,
        words_list,
        word_width,
        winning_ep,
        model_checkpoint_dir,
        gamma=0.0,
        pretrained_model_path=None,
        save=False,
        min_reward=9.9,
        every_n_save=100,
    ):
        """Store the shared handles and build the local network.

        Args:
            max_ep: The worker stops once the shared episode counter reaches
                this value.
            gnet: Global network whose parameters receive the gradients.
            opt: Optimizer that is stepped after gradients are handed over.
            global_ep: Shared episode counter.
            global_ep_r: Shared running episode reward.
            res_queue: Queue receiving the running reward after each episode
                and ``None`` when the worker stops.
            name: Worker number, formatted into the process name.
            env: Environment; its ``unwrapped`` attribute is what gets used.
            N_S, N_A, words_list, word_width: Forwarded to ``Net``.
                ``words_list`` is also indexed to turn the goal word and the
                last action into the values that are compared and logged.
            winning_ep: Shared counter of episodes whose last action matched
                the goal word.
            model_checkpoint_dir: Directory for saved checkpoints.
            gamma: Discount factor applied when building value targets.
            pretrained_model_path: Optional state-dict file loaded into the
                local net.
            save: Whether checkpoints are written at all.
            min_reward: Minimum running reward required to save.
            every_n_save: Save only when the episode counter is a multiple of
                this value.
        """
        super().__init__()
        self.max_ep = max_ep
        self.name = "w%02i" % name
        self.g_ep = global_ep
        self.g_ep_r = global_ep_r
        self.res_queue = res_queue
        self.winning_ep = winning_ep
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list

        # local network
        self.lnet = Net(N_S, N_A, words_list, word_width)
        if pretrained_model_path:
            # The local net's outputs are converted with ``.numpy()`` below,
            # so it is used on the CPU; mapping to CPU lets checkpoints that
            # were saved from a GPU load on CPU-only hosts as well.
            # NOTE: torch.load unpickles the file -- only pass trusted paths.
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
            self.lnet.load_state_dict(state_dict)

        self.env = env.unwrapped
        self.gamma = gamma
        self.model_checkpoint_dir = model_checkpoint_dir
        self.save = save
        self.min_reward = min_reward
        self.every_n_save = every_n_save

    def run(self):
        """Play episodes until the shared counter reaches ``max_ep``.

        The transitions of an episode are buffered; when the environment
        reports ``done`` the networks are synchronised, the episode is
        recorded and a checkpoint is written if the save conditions hold.
        ``None`` is always put on the result queue when the worker stops,
        including when the loop raises, so a consumer waiting for the
        sentinel is not left blocked.
        """
        try:
            while self.g_ep.value < self.max_ep:
                s = self.env.reset()
                buffer_s, buffer_a, buffer_r = [], [], []
                ep_r = 0.0
                while True:
                    # s[None, :] adds a leading axis before wrapping.
                    a = self.lnet.choose_action(v_wrap(s[None, :]))
                    s_, r, done, _ = self.env.step(a)
                    ep_r += r
                    buffer_a.append(a)
                    buffer_s.append(s)
                    buffer_r.append(r)

                    if done:
                        # update global and assign to local net
                        self.push_and_pull(done, s_, buffer_s, buffer_a, buffer_r)
                        goal_word = self.word_list[self.env.goal_word]
                        self.record(
                            ep_r, goal_word, self.word_list[a], len(buffer_a)
                        )
                        self.save_model()
                        # Buffers are re-created at the top of the next episode.
                        break
                    s = s_
        finally:
            self.res_queue.put(None)

    def push_and_pull(self, done, s_, bs, ba, br):
        """Push local gradients to the global net and pull its weights back.

        Args:
            done: Whether ``s_`` is terminal; if so the bootstrap value is 0.
            s_: State following the last buffered transition.
            bs: Buffered states.
            ba: Buffered actions; ``ba[0].dtype`` selects how they are wrapped.
            br: Buffered rewards.
        """
        if done:
            v_s_ = 0.0  # terminal
        else:
            v_s_ = self.lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]

        # Discounted targets, accumulated from the last reward backwards.
        buffer_v_target = []
        for r in reversed(br):
            v_s_ = r + self.gamma * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()

        if ba[0].dtype == np.int64:
            actions = v_wrap(np.array(ba), dtype=np.int64)
        else:
            actions = v_wrap(np.vstack(ba))

        loss = self.lnet.loss_func(
            v_wrap(np.vstack(bs)),
            actions,
            v_wrap(np.array(buffer_v_target)[:, None]),
        )

        # calculate local gradients and push local parameters to global
        self.opt.zero_grad()
        loss.backward()
        for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
            gp._grad = lp.grad
        self.opt.step()

        # pull global parameters
        self.lnet.load_state_dict(self.gnet.state_dict())

    def save_model(self):
        """Write the global net's state dict when the save conditions hold.

        A checkpoint is written only if saving is enabled, the running reward
        is at least ``min_reward`` and the episode counter is a multiple of
        ``every_n_save``.  The counter is read once so the file name always
        carries the same episode number that passed the check, and the
        checkpoint directory is created if it does not exist yet.
        """
        if not self.save:
            return
        episode = self.g_ep.value
        if self.g_ep_r.value < self.min_reward or episode % self.every_n_save != 0:
            return
        os.makedirs(self.model_checkpoint_dir, exist_ok=True)
        torch.save(
            self.gnet.state_dict(),
            os.path.join(self.model_checkpoint_dir, f"model_{episode}.pth"),
        )

    def record(self, ep_r, goal_word, action, action_number):
        """Update the shared statistics for one finished episode.

        Increments the episode counter, folds ``ep_r`` into the running
        reward (0.99 / 0.01 weighting, seeded with the first reward), puts
        the running reward on the result queue, counts a win when
        ``goal_word == action`` and prints a progress line every 100
        episodes.

        Counter and reward are captured while their locks are held so the
        queued value and the progress line describe this episode rather than
        whatever another worker wrote in the meantime.

        Args:
            ep_r: Total reward of the episode.
            goal_word: Target word of the episode.
            action: Word selected by the last action.
            action_number: Number of actions taken in the episode.
        """
        with self.g_ep.get_lock():
            self.g_ep.value += 1
            episode = self.g_ep.value

        with self.g_ep_r.get_lock():
            if self.g_ep_r.value == 0.0:
                self.g_ep_r.value = ep_r
            else:
                self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
            running_reward = self.g_ep_r.value

        self.res_queue.put(running_reward)

        if goal_word == action:
            _locked_increment(self.winning_ep)

        if episode % 100 == 0:
            print(
                self.name,
                "Ep:",
                episode,
                "| Ep_r: %.0f" % running_reward,
                "| Goal :",
                goal_word,
                "| Action: ",
                action,
                "| Actions: ",
                action_number,
            )