""" Reinforcement Learning (A3C) using Pytroch + multiprocessing. The most simple implementation for continuous action. View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/). """ import torch import torch.nn as nn import torch.nn.functional as F import gym import torch.multiprocessing as mp from .utils import v_wrap, set_init, push_and_pull, record import numpy as np GAMMA = 0.9 class Net(nn.Module): def __init__(self, s_dim, a_dim, word_list, words_width): super(Net, self).__init__() self.s_dim = s_dim self.a_dim = a_dim n_emb = 32 # self.pi1 = nn.Linear(s_dim, 128) # self.pi2 = nn.Linear(128, a_dim) self.v1 = nn.Linear(s_dim, 256) self.v2 = nn.Linear(256, n_emb) self.v3 = nn.Linear(n_emb, 1) set_init([ self.v1, self.v2]) # n_emb self.distribution = torch.distributions.Categorical word_width = 26 * words_width word_array = np.zeros((len(word_list), word_width)) self.actor_head = nn.Linear(n_emb, n_emb) for i, word in enumerate(word_list): for j, c in enumerate(word): word_array[i, j*26 + (ord(c) - ord('A'))] = 1 self.words = torch.Tensor(word_array) self.f_word = nn.Sequential( nn.Linear(word_width, 64), nn.Tanh(), nn.Linear(64, n_emb), ) def forward(self, x): # pi1 = torch.tanh(self.pi1(x)) fw = self.f_word( self.words.to(x.device.index), ).transpose(0, 1) # logits = self.pi2(pi1) v1 = torch.tanh(self.v1(x)) values = self.v2(v1) logits = torch.log_softmax( torch.tensordot(self.actor_head(values), fw, dims=((1,), (0,))), dim=-1) values = self.v3(values) return logits, values def choose_action(self, s): self.eval() logits, _ = self.forward(s) prob = F.softmax(logits, dim=1).data m = self.distribution(prob) return m.sample().numpy()[0] def loss_func(self, s, a, v_t): self.train() logits, values = self.forward(s) td = v_t - values c_loss = td.pow(2) probs = F.softmax(logits, dim=1) m = self.distribution(probs) exp_v = m.log_prob(a) * td.detach().squeeze() a_loss = -exp_v total_loss = (c_loss + a_loss).mean() return total_loss class 
class Worker(mp.Process):
    """One asynchronous A3C worker process.

    Owns a private copy of the network (``lnet``) and its own environment,
    periodically syncing gradients/weights with the shared global network
    through ``push_and_pull``.
    """

    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r,
                 res_queue, name, env, N_S, N_A, words_list, word_width,
                 winning_ep):
        super().__init__()
        self.max_ep = max_ep
        self.name = 'w%02i' % name
        # Shared (cross-process) bookkeeping objects.
        self.g_ep = global_ep
        self.g_ep_r = global_ep_r
        self.res_queue = res_queue
        self.winning_ep = winning_ep
        # Global network and its shared optimizer.
        self.gnet = gnet
        self.opt = opt
        self.word_list = words_list
        # Worker-local copy of the network.
        self.lnet = Net(N_S, N_A, words_list, word_width)
        self.env = env.unwrapped

    def run(self):
        """Roll out episodes until the shared episode counter hits max_ep."""
        while self.g_ep.value < self.max_ep:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.
            done = False
            while not done:
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)
                if done:
                    # Sync: push local gradients to the global net, then
                    # pull back the updated global weights.
                    push_and_pull(self.opt, self.lnet, self.gnet, done, s_,
                                  buffer_s, buffer_a, buffer_r, GAMMA)
                    goal_word = self.word_list[self.env.goal_word]
                    record(self.g_ep, self.g_ep_r, ep_r, self.res_queue,
                           self.name, goal_word, self.word_list[a],
                           len(buffer_a), self.winning_ep)
                    buffer_s, buffer_a, buffer_r = [], [], []
                else:
                    s = s_
        # Sentinel so the consumer knows this worker is finished.
        self.res_queue.put(None)