wordle-solver / a3c /discrete_A3C.py
santit96's picture
Allow to pass hiperparameters as command line arguments
f05ece6
raw
history blame
4.07 kB
"""
Reinforcement Learning (A3C) using PyTorch + multiprocessing.
The simplest implementation for discrete action.
View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import torch.multiprocessing as mp
from .utils import v_wrap, set_init, push_and_pull, record
import numpy as np
# Forwarded to push_and_pull on every global update; presumably the reward
# discount factor used when computing value targets -- confirm in utils.
GAMMA = 0.9
class Net(nn.Module):
    """Actor-critic network whose policy scores every word of a fixed vocabulary.

    A shared trunk (``v1`` -> ``v2``) maps the state to a 32-dimensional
    embedding. The critic head (``v3``) turns that embedding into a scalar
    value. The actor projects it through ``actor_head`` and takes the dot
    product with a learned embedding (``f_word``) of each one-hot encoded
    word, producing one logit per vocabulary word.
    """

    def __init__(self, s_dim, a_dim, word_list, words_width):
        """Build the layers and the one-hot word matrix.

        Args:
            s_dim: Size of the state vector fed to the first linear layer.
            a_dim: Stored on the instance only; the number of policy outputs
                is ``len(word_list)``.
            word_list: Sequence of candidate words made of the letters A-Z
                (case-insensitive).
            words_width: Number of letter positions reserved per word.

        Raises:
            ValueError: If a word contains a character outside A-Z.
        """
        super(Net, self).__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        n_emb = 32
        self.v1 = nn.Linear(s_dim, 256)
        self.v2 = nn.Linear(256, n_emb)
        self.v3 = nn.Linear(n_emb, 1)
        # Only the trunk layers get the custom initialisation.
        set_init([self.v1, self.v2])
        self.distribution = torch.distributions.Categorical
        self.actor_head = nn.Linear(n_emb, n_emb)
        # Plain tensor attribute (neither parameter nor buffer), so it is not
        # part of parameters() / state_dict(); forward() moves it on demand.
        self.words = torch.Tensor(self._one_hot_words(word_list, words_width))
        self.f_word = nn.Sequential(
            nn.Linear(26 * words_width, 64),
            nn.Tanh(),
            nn.Linear(64, n_emb),
        )

    @staticmethod
    def _one_hot_words(word_list, words_width):
        """Return a (len(word_list), 26 * words_width) one-hot letter matrix."""
        encoded = np.zeros((len(word_list), 26 * words_width))
        for row, word in enumerate(word_list):
            for pos, char in enumerate(word.upper()):
                letter = ord(char) - ord('A')
                if not 0 <= letter < 26:
                    # Without this check the index would silently land in a
                    # neighbouring letter slot or wrap around to the row end.
                    raise ValueError(
                        "word %r contains a character outside A-Z: %r"
                        % (word, char)
                    )
                encoded[row, pos * 26 + letter] = 1
        return encoded

    def forward(self, x):
        """Compute policy log-probabilities and state values.

        Args:
            x: State batch of shape (batch, s_dim).

        Returns:
            Tuple ``(logits, values)``: ``logits`` are log-softmax scores of
            shape (batch, n_words); ``values`` has shape (batch, 1).
        """
        # Use the full device (type and index); the index alone is None on
        # CPU and is interpreted as a CUDA ordinal when it is an integer.
        word_emb = self.f_word(self.words.to(x.device)).transpose(0, 1)
        hidden = self.v2(torch.tanh(self.v1(x)))
        scores = torch.tensordot(
            self.actor_head(hidden), word_emb, dims=((1,), (0,))
        )
        logits = torch.log_softmax(scores, dim=-1)
        values = self.v3(hidden)
        return logits, values

    def choose_action(self, s):
        """Sample a word index for the first state in ``s``.

        Switches the module to eval mode as a side effect.

        Args:
            s: State batch of shape (batch, s_dim).

        Returns:
            The sampled index for row 0 as a numpy integer scalar.
        """
        self.eval()
        with torch.no_grad():
            logits, _ = self.forward(s)
        # forward() already returns normalised log-probabilities, so feed
        # them to the distribution directly instead of re-applying softmax.
        sampled = self.distribution(logits=logits).sample()
        return sampled.cpu().numpy()[0]

    def loss_func(self, s, a, v_t):
        """Return the scalar actor-critic loss for a batch.

        Switches the module to train mode as a side effect.

        Args:
            s: State batch of shape (batch, s_dim).
            a: Indices of the actions taken, one per state.
            v_t: Target values; must broadcast against the (batch, 1) critic
                output -- presumably shape (batch, 1), confirm with the caller.

        Returns:
            0-dim tensor: mean squared TD error plus mean policy-gradient loss.
        """
        self.train()
        logits, values = self.forward(s)
        td = v_t - values
        c_loss = td.pow(2)
        log_prob = self.distribution(logits=logits).log_prob(a)
        # The TD error is detached so the actor term does not train the critic.
        a_loss = -(log_prob * td.detach().squeeze())
        # Equal to (c_loss + a_loss).mean() but avoids broadcasting the two
        # terms into a (batch, batch) matrix first.
        return c_loss.mean() + a_loss.mean()
class Worker(mp.Process):
    """Worker process that plays episodes with a local copy of the network.

    After each finished episode the worker calls ``push_and_pull`` with the
    collected trajectory and both networks, then reports the episode through
    ``record``.
    """

    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
        """Store the shared handles and create the local network.

        Args:
            max_ep: Global episode count at which the worker stops.
            gnet: Global network handed to ``push_and_pull``.
            opt: Optimizer handed to ``push_and_pull``.
            global_ep: Shared episode counter (read through ``.value``).
            global_ep_r: Shared reward statistic handed to ``record``.
            res_queue: Queue handed to ``record``; receives ``None`` when the
                worker finishes.
            name: Integer worker id, formatted into the process name.
            env: Environment; its ``unwrapped`` form is used.
            N_S: State dimension for the local network.
            N_A: Action dimension for the local network.
            words_list: Vocabulary, indexed by action and by goal word.
            word_width: Number of letter positions per word.
            winning_ep: Shared object handed to ``record``.
        """
        super().__init__()
        self.max_ep = max_ep
        self.name = 'w%02i' % name
        self.g_ep = global_ep
        self.g_ep_r = global_ep_r
        self.res_queue = res_queue
        self.winning_ep = winning_ep
        self.gnet = gnet
        self.opt = opt
        self.word_list = words_list
        # Local network, built with the same arguments a Net constructor takes.
        self.lnet = Net(N_S, N_A, words_list, word_width)
        self.env = env.unwrapped

    def run(self):
        """Play episodes until the shared counter reaches ``max_ep``."""
        while self.g_ep.value < self.max_ep:
            state = self.env.reset()
            states, actions, rewards = [], [], []
            episode_reward = 0.
            done = False
            while not done:
                action = self.lnet.choose_action(v_wrap(state[None, :]))
                next_state, reward, done, _ = self.env.step(action)
                episode_reward += reward
                actions.append(action)
                states.append(state)
                rewards.append(reward)
                state = next_state

            # Episode finished: update the global net and sync the local one.
            push_and_pull(
                self.opt, self.lnet, self.gnet, done, next_state,
                states, actions, rewards, GAMMA,
            )
            goal_word = self.word_list[self.env.goal_word]
            last_guess = self.word_list[action]
            record(
                self.g_ep, self.g_ep_r, episode_reward, self.res_queue,
                self.name, goal_word, last_guess, len(actions),
                self.winning_ep,
            )
        # Sentinel put on the result queue once this worker stops.
        self.res_queue.put(None)