""" | |
Reinforcement Learning (A3C) using Pytroch + multiprocessing. | |
The most simple implementation for continuous action. | |
View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/). | |
""" | |
import gym
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F

from .utils import v_wrap, set_init, push_and_pull, record

GAMMA = 0.9  # discount factor for the value targets


class Net(nn.Module):
    def __init__(self, s_dim, a_dim, word_list, words_width):
        super(Net, self).__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        n_emb = 32

        # Critic trunk; v2's output embedding is shared with the actor head.
        self.v1 = nn.Linear(s_dim, 256)
        self.v2 = nn.Linear(256, n_emb)
        self.v3 = nn.Linear(n_emb, 1)
        set_init([self.v1, self.v2])
        self.distribution = torch.distributions.Categorical

        # One-hot encode every candidate word: one 26-way slot per letter.
        word_width = 26 * words_width
        word_array = np.zeros((len(word_list), word_width))
        for i, word in enumerate(word_list):
            for j, c in enumerate(word):
                word_array[i, j * 26 + (ord(c) - ord('A'))] = 1
        self.words = torch.Tensor(word_array)

        # Actor: project the shared embedding and score it against each word.
        self.actor_head = nn.Linear(n_emb, n_emb)
        self.f_word = nn.Sequential(
            nn.Linear(word_width, 64),
            nn.Tanh(),
            nn.Linear(64, n_emb),
        )
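
    # Encoding illustration (added note, not in the original): with
    # words_width == 5, the word "ABASE" sets columns 0 ('A' in slot 0),
    # 27 ('B' in slot 1), 52 ('A'), 96 ('S'), and 108 ('E') of its row in
    # word_array; every other column stays 0.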

    def forward(self, x):
        # Embed every candidate word: (V, word_width) -> (n_emb, V).
        fw = self.f_word(
            self.words.to(x.device),
        ).transpose(0, 1)
        v1 = torch.tanh(self.v1(x))
        hidden = self.v2(v1)
        # Score each word embedding against the projected state embedding.
        logits = torch.log_softmax(
            torch.tensordot(self.actor_head(hidden), fw,
                            dims=((1,), (0,))),
            dim=-1)
        values = self.v3(hidden)
        return logits, values
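
    # Shape sketch (B = batch size, V = number of candidate words):
    #   x: (B, s_dim) -> hidden: (B, n_emb); self.words: (V, 26*words_width)
    #   -> fw: (n_emb, V); the tensordot gives logits: (B, V); values: (B, 1).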

    def choose_action(self, s):
        self.eval()
        logits, _ = self.forward(s)
        prob = F.softmax(logits, dim=1).data
        m = self.distribution(prob)
        return m.sample().numpy()[0]  # index of the sampled word

    def loss_func(self, s, a, v_t):
        self.train()
        logits, values = self.forward(s)
        # Critic loss: squared TD error against the value target v_t.
        td = v_t - values
        c_loss = td.pow(2)
        # Actor loss: -log pi(a|s) weighted by the detached advantage.
        probs = F.softmax(logits, dim=1)
        m = self.distribution(probs)
        exp_v = m.log_prob(a) * td.detach().squeeze()
        a_loss = -exp_v
        total_loss = (c_loss + a_loss).mean()
        return total_loss
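
# Usage sketch (illustration only; `WORDS` and `state` are hypothetical names):
#   net = Net(N_S, N_A, WORDS, words_width=5)
#   a = net.choose_action(v_wrap(state[None, :]))  # sampled index into WORDS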


class Worker(mp.Process):
    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name,
                 env, N_S, N_A, words_list, word_width, winning_ep):
        super(Worker, self).__init__()
        self.max_ep = max_ep
        self.name = 'w%02i' % name
        self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list
        self.lnet = Net(N_S, N_A, words_list, word_width)  # local network
        self.env = env.unwrapped

    def run(self):
        while self.g_ep.value < self.max_ep:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.
            while True:
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)
                if done:
                    # Episode finished: push local gradients to the global net,
                    # then pull its latest weights back into the local net.
                    push_and_pull(self.opt, self.lnet, self.gnet, done, s_,
                                  buffer_s, buffer_a, buffer_r, GAMMA)
                    goal_word = self.word_list[self.env.goal_word]
                    record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name,
                           goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
                    buffer_s, buffer_a, buffer_r = [], [], []
                    break
                s = s_
        self.res_queue.put(None)
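

# A minimal launch sketch (illustration only, not part of this file). It
# assumes a SharedAdam optimizer shared across processes (as in the original
# A3C tutorial) and a Wordle-style gym env exposing `goal_word`; the names
# `SharedAdam`, `WordleEnv`, `WORDS`, and `WORD_WIDTH` are hypothetical.
#
# if __name__ == "__main__":
#     env = gym.make('WordleEnv')                      # hypothetical env id
#     N_S, N_A = env.observation_space.shape[0], env.action_space.n
#     gnet = Net(N_S, N_A, WORDS, WORD_WIDTH)          # global network
#     gnet.share_memory()                              # share weights across workers
#     opt = SharedAdam(gnet.parameters(), lr=1e-4)
#     global_ep, global_ep_r = mp.Value('i', 0), mp.Value('d', 0.)
#     winning_ep, res_queue = mp.Value('i', 0), mp.Queue()
#     workers = [Worker(3000, gnet, opt, global_ep, global_ep_r, res_queue, i,
#                       env, N_S, N_A, WORDS, WORD_WIDTH, winning_ep)
#                for i in range(mp.cpu_count())]
#     [w.start() for w in workers]
#     res = []                                         # collect episode rewards
#     while True:
#         r = res_queue.get()
#         if r is None:
#             break
#         res.append(r)
#     [w.join() for w in workers]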