# wordle-solver / a3c / worker.py
# Commit 18a7031 — "Add possibility to train from a pretrained model" (santit96)
"""
Worker class implementation of the a3c discrete algorithm
"""
import os
import torch
import numpy as np
import torch.multiprocessing as mp
from torch import nn
from .net import Net
from .utils import v_wrap
GAMMA = 0.65
class Worker(mp.Process):
    """
    A3C worker process for the discrete algorithm.

    Each worker owns a local copy of the policy/value network and its own
    environment. It plays episodes, pushes local gradients to the shared
    global network, pulls back the updated global weights, and reports the
    smoothed running reward to the parent through ``res_queue``.
    """

    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir, pretrained_model_path=None):
        """
        :param max_ep: total number of global episodes to train for
        :param gnet: shared global network
        :param opt: shared optimizer that updates ``gnet``
        :param global_ep: shared (mp.Value) counter of finished episodes
        :param global_ep_r: shared (mp.Value) smoothed episode reward
        :param res_queue: queue for reporting the running reward to the parent
        :param name: integer id of this worker (used only for display)
        :param env: gym-like Wordle environment
        :param N_S: observation-space size
        :param N_A: action-space size
        :param words_list: list of candidate words (action index -> word)
        :param word_width: encoded width of a single word
        :param winning_ep: shared (mp.Value) counter of solved episodes
        :param model_checkpoint_dir: directory where checkpoints are written
        :param pretrained_model_path: optional path to a state dict used to
            warm-start the local network
        """
        super().__init__()
        self.max_ep = max_ep
        self.name = 'w%02i' % name
        self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list
        # local network; optionally warm-started from a pretrained checkpoint
        self.lnet = Net(N_S, N_A, words_list, word_width)
        if pretrained_model_path:
            self.lnet.load_state_dict(torch.load(pretrained_model_path))
        self.env = env.unwrapped
        self.model_checkpoint_dir = model_checkpoint_dir

    def run(self):
        """Episode loop executed in the spawned worker process."""
        while self.g_ep.value < self.max_ep:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.
            while True:
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)
                if done:  # update global net and re-sync the local one
                    self.push_and_pull(done, s_, buffer_s,
                                       buffer_a, buffer_r, GAMMA)
                    goal_word = self.word_list[self.env.goal_word]
                    self.record(ep_r, goal_word,
                                self.word_list[a], len(buffer_a))
                    self.save_model()
                    buffer_s, buffer_a, buffer_r = [], [], []
                    break
                s = s_
        # sentinel: tells the parent process this worker is finished
        self.res_queue.put(None)

    def push_and_pull(self, done, s_, bs, ba, br, gamma):
        """
        Compute discounted value targets for the finished rollout,
        back-propagate the local loss, copy the local gradients onto the
        global network, step the shared optimizer, then pull the updated
        global weights back into the local network.

        :param done: whether the rollout ended in a terminal state
        :param s_: the state following the last buffered transition
        :param bs: buffered states
        :param ba: buffered actions
        :param br: buffered rewards
        :param gamma: discount factor for future rewards
        """
        if done:
            v_s_ = 0.  # terminal state has zero value
        else:
            v_s_ = self.lnet.forward(v_wrap(
                s_[None, :]))[-1].data.numpy()[0, 0]
        buffer_v_target = []
        for r in br[::-1]:  # bootstrap backwards through the rollout
            v_s_ = r + gamma * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()
        loss = self.lnet.loss_func(
            v_wrap(np.vstack(bs)),
            # actions may be numpy int64 scalars or arrays; wrap accordingly
            v_wrap(np.array(ba), dtype=np.int64) if ba[0].dtype == np.int64 else v_wrap(np.vstack(ba)),
            v_wrap(np.array(buffer_v_target)[:, None]))
        # calculate local gradients and push local parameters to global
        self.opt.zero_grad()
        loss.backward()
        for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
            gp._grad = lp.grad
        self.opt.step()
        # pull global parameters
        self.lnet.load_state_dict(self.gnet.state_dict())

    def save_model(self):
        """Checkpoint the global network once the running reward is high."""
        if self.g_ep_r.value >= 9.9 and self.g_ep.value % 100 == 0:
            torch.save(self.gnet.state_dict(), os.path.join(
                self.model_checkpoint_dir, f'model_{self.g_ep.value}.pth'))

    def record(self, ep_r, goal_word, action, action_number):
        """
        Update the shared episode counters and the smoothed reward, report
        the running reward on ``res_queue``, and log progress every 100
        episodes.

        :param ep_r: total reward of the finished episode
        :param goal_word: the word the environment expected
        :param action: the last word guessed in the episode
        :param action_number: number of guesses used in the episode
        """
        with self.g_ep.get_lock():
            self.g_ep.value += 1
        with self.g_ep_r.get_lock():
            if self.g_ep_r.value == 0.:
                self.g_ep_r.value = ep_r
            else:
                # exponential moving average of episode reward
                self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
        self.res_queue.put(self.g_ep_r.value)
        if goal_word == action:
            # Guard the shared counter: unguarded += on an mp.Value is a
            # read-modify-write race and can lose increments from
            # concurrent workers.
            with self.winning_ep.get_lock():
                self.winning_ep.value += 1
        if self.g_ep.value % 100 == 0:
            print(
                self.name,
                "Ep:", self.g_ep.value,
                "| Ep_r: %.0f" % self.g_ep_r.value,
                "| Goal :", goal_word,
                "| Action: ", action,
                "| Actions: ", action_number
            )