"""
Worker class implementation of the a3c discrete algorithm
"""
import os
import torch
import numpy as np
import torch.multiprocessing as mp

from .net import Net
from .utils import v_wrap

class Worker(mp.Process):
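    """
    A3C worker process: owns a local copy of the actor-critic network, plays
    episodes in its own environment instance, and synchronises gradients and
    weights with the shared global network (`gnet`).
    """
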
    def __init__(
        self,
        max_ep,
        gnet,
        opt,
        global_ep,
        global_ep_r,
        res_queue,
        name,
        env,
        N_S,
        N_A,
        words_list,
        word_width,
        winning_ep,
        model_checkpoint_dir,
        gamma=0.,
        pretrained_model_path=None,
        save=False,
        min_reward=9.9,
        every_n_save=100
    ):
        super().__init__()
        self.max_ep = max_ep
        self.name = 'w%02i' % name
        self.g_ep = global_ep
        self.g_ep_r = global_ep_r
        self.res_queue = res_queue
        self.winning_ep = winning_ep
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list
        # local network
        self.lnet = Net(N_S, N_A, words_list, word_width)
        if pretrained_model_path:
            self.lnet.load_state_dict(torch.load(pretrained_model_path))
        self.env = env.unwrapped
        self.gamma = gamma
        self.model_checkpoint_dir = model_checkpoint_dir
        self.save = save
        self.min_reward = min_reward
        self.every_n_save = every_n_save
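
    # Training loop: play full episodes until the shared episode counter
    # reaches `max_ep`; on each terminal step, push gradients to the global
    # network, pull back the latest weights, and log the result.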
    def run(self):
        while self.g_ep.value < self.max_ep:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.
            while True:
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)
                if done:  # update global and assign to local net
                    # sync
                    self.push_and_pull(done, s_, buffer_s, buffer_a, buffer_r)
                    goal_word = self.word_list[self.env.goal_word]
                    self.record(ep_r, goal_word, self.word_list[a], len(buffer_a))
                    self.save_model()
                    buffer_s, buffer_a, buffer_r = [], [], []
                    break
                s = s_
        self.res_queue.put(None)
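
    # Compute bootstrapped discounted returns for the finished rollout, take
    # one optimiser step on the *global* network using the local gradients,
    # then copy the updated global weights back into the local network.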
    def push_and_pull(self, done, s_, bs, ba, br):
        if done:
            v_s_ = 0.  # terminal
        else:
            v_s_ = self.lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]
        buffer_v_target = []
        for r in br[::-1]:  # reverse buffer r
            v_s_ = r + self.gamma * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()
        loss = self.lnet.loss_func(
            v_wrap(np.vstack(bs)),
            v_wrap(np.array(ba), dtype=np.int64) if ba[0].dtype == np.int64 else v_wrap(np.vstack(ba)),
            v_wrap(np.array(buffer_v_target)[:, None]))
        # calculate local gradients and push local parameters to global
        self.opt.zero_grad()
        loss.backward()
        for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
            gp._grad = lp.grad
        self.opt.step()
        # pull global parameters
        self.lnet.load_state_dict(self.gnet.state_dict())
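
    # Checkpoint the global network when saving is enabled, the running
    # reward has cleared `min_reward`, and the episode count is a multiple
    # of `every_n_save`.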
    def save_model(self):
        if self.save and self.g_ep_r.value >= self.min_reward and self.g_ep.value % self.every_n_save == 0:
            torch.save(self.gnet.state_dict(), os.path.join(
                self.model_checkpoint_dir, f'model_{self.g_ep.value}.pth'))
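
    # Update the shared episode counter and the exponentially smoothed
    # running reward (0.99 / 0.01 moving average), count wins, and print
    # progress every 100 episodes.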
    def record(self, ep_r, goal_word, action, action_number):
        with self.g_ep.get_lock():
            self.g_ep.value += 1
        with self.g_ep_r.get_lock():
            if self.g_ep_r.value == 0.:
                self.g_ep_r.value = ep_r
            else:
                self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
        self.res_queue.put(self.g_ep_r.value)
        if goal_word == action:
            self.winning_ep.value += 1
        if self.g_ep.value % 100 == 0:
            print(
                self.name,
                "Ep:", self.g_ep.value,
                "| Ep_r: %.0f" % self.g_ep_r.value,
                "| Goal:", goal_word,
                "| Action:", action,
                "| Actions:", action_number
            )
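

# Usage sketch (illustrative only; `SharedAdam`, `make_env`, and the shape
# constants are assumptions about the surrounding project, not names defined
# in this module):
#
#     env = make_env()
#     gnet = Net(N_S, N_A, words, word_width)   # global network
#     gnet.share_memory()                        # share weights across processes
#     opt = SharedAdam(gnet.parameters(), lr=1e-4)
#     global_ep = mp.Value('i', 0)               # shared episode counter
#     global_ep_r = mp.Value('d', 0.)            # shared running reward
#     winning_ep = mp.Value('i', 0)              # shared win counter
#     res_queue = mp.Queue()                     # results stream for the parent
#     workers = [
#         Worker(4000, gnet, opt, global_ep, global_ep_r, res_queue, i,
#                env, N_S, N_A, words, word_width, winning_ep, 'checkpoints')
#         for i in range(mp.cpu_count())
#     ]
#     [w.start() for w in workers]
#     [w.join() for w in workers]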