wordle-solver / a3c / discrete_A3C.py
"""
Reinforcement Learning (A3C) using Pytroch + multiprocessing.
The most simple implementation for continuous action.
View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import torch.multiprocessing as mp
from .utils import v_wrap, set_init, push_and_pull, record
import numpy as np
GAMMA = 0.7  # reward discount factor


class Net(nn.Module):
def __init__(self, s_dim, a_dim, word_list, words_width):
super(Net, self).__init__()
self.s_dim = s_dim
self.a_dim = a_dim
# n_emb = 32
        # Each word is encoded as a flat one-hot vector:
        # 26 possible letters for each of the `words_width` character slots.
        word_width = 26 * words_width
        layers = [
            nn.Linear(s_dim, word_width),
            nn.Tanh(),
            # nn.Linear(128, word_width),
            # nn.Tanh(),
            # nn.Linear(256, n_emb),
            # nn.Tanh(),
        ]
        self.v1 = nn.Sequential(*layers)  # shared feature trunk
        self.v4 = nn.Linear(word_width, 1)  # critic head: scalar state value
        self.actor_head = nn.Linear(word_width, word_width)  # actor head
        self.distribution = torch.distributions.Categorical
        # Precompute the one-hot encoding of the whole vocabulary;
        # column i holds the encoding of word_list[i].
        word_array = np.zeros((word_width, len(word_list)))
        for i, word in enumerate(word_list):
            for j, c in enumerate(word):
                word_array[j * 26 + (ord(c) - ord('A')), i] = 1
        self.words = torch.Tensor(word_array)
# self.f_word = nn.Sequential(
# nn.Linear(word_width, 64),
# nn.ReLU(),
# nn.Linear(64, n_emb),
# )
    def forward(self, x):
        # fw = self.f_word(
        #     self.words.to(x.device.index),
        # ).transpose(0, 1)
        values = self.v1(x.float())
        # Score every vocabulary word: project the shared features through the
        # actor head and take the dot product with each word's one-hot encoding.
        logits = torch.log_softmax(
            torch.tensordot(self.actor_head(values), self.words,
                            dims=((1,), (0,))),
            dim=-1)
        values = self.v4(values)  # critic estimate of the state value
        return logits, values
    def choose_action(self, s):
        # Sample a word index from the current policy (used during rollouts).
        self.eval()
        logits, _ = self.forward(s)
        prob = F.softmax(logits, dim=1).data
        m = self.distribution(prob)
        return m.sample().numpy()[0]
    def loss_func(self, s, a, v_t):
        self.train()
        logits, values = self.forward(s)
        td = v_t - values  # advantage: discounted return minus value estimate
        c_loss = td.pow(2)  # critic loss: squared TD error
        probs = F.softmax(logits, dim=1)
        m = self.distribution(probs)
        exp_v = m.log_prob(a) * td.detach().squeeze()  # policy gradient weighted by advantage
        a_loss = -exp_v
        total_loss = (c_loss + a_loss).mean()
        return total_loss
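

# Each Worker owns a local copy of the network and its own environment instance.
# It collects full episodes and, when an episode ends, pushes gradients to the
# shared global network and pulls the updated weights back (utils.push_and_pull).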
class Worker(mp.Process):
def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
super(Worker, self).__init__()
self.max_ep = max_ep
self.name = 'w%02i' % name
self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
self.gnet, self.opt = gnet, opt
self.word_list = words_list
self.lnet = Net(N_S, N_A, words_list, word_width) # local network
self.env = env.unwrapped
def run(self):
while self.g_ep.value < self.max_ep:
s = self.env.reset()
buffer_s, buffer_a, buffer_r = [], [], []
ep_r = 0.
while True:
a = self.lnet.choose_action(v_wrap(s[None, :]))
s_, r, done, _ = self.env.step(a)
ep_r += r
buffer_a.append(a)
buffer_s.append(s)
buffer_r.append(r)
                if done:  # episode finished: update global net and re-sync local net
                    # push local gradients to the shared global net, then pull fresh weights
                    push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
goal_word = self.word_list[self.env.goal_word]
record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
buffer_s, buffer_a, buffer_r = [], [], []
break
s = s_
        self.res_queue.put(None)  # signal the result consumer that this worker is done
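

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not the project's actual entry point):
# the real training script lives elsewhere in the repo. `make_wordle_env` and
# `SharedAdam` are hypothetical placeholders for the environment factory and
# the shared-state optimizer the project actually uses.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     env = make_wordle_env()                        # hypothetical env factory
#     word_list, word_width = env.words, env.word_width
#     N_S = env.observation_space.shape[0]
#     N_A = env.action_space.n
#
#     gnet = Net(N_S, N_A, word_list, word_width)    # global network
#     gnet.share_memory()                            # share parameters across processes
#     opt = SharedAdam(gnet.parameters(), lr=1e-4)   # hypothetical shared optimizer
#
#     global_ep, global_ep_r = mp.Value('i', 0), mp.Value('d', 0.)
#     winning_ep = mp.Value('i', 0)
#     res_queue = mp.Queue()
#
#     workers = [
#         Worker(5000, gnet, opt, global_ep, global_ep_r, res_queue, i,
#                env, N_S, N_A, word_list, word_width, winning_ep)
#         for i in range(mp.cpu_count())
#     ]
#     for w in workers:
#         w.start()
#     for w in workers:
#         w.join()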