wordle-solver / a3c /discrete_A3C.py
santit96's picture
Now the one-hot encoding of the words is embedded inside the net; it is no longer necessary to one-hot encode in the env
abff1ef
raw
history blame
4.33 kB
"""
Reinforcement Learning (A3C) using Pytroch + multiprocessing.
The most simple implementation for continuous action.
View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import torch.multiprocessing as mp
from .utils import v_wrap, set_init, push_and_pull, record
import numpy as np
UPDATE_GLOBAL_ITER = 5
GAMMA = 0.9
MAX_EP = 100000
class Net(nn.Module):
    """Actor-critic network for the Wordle A3C agent.

    All candidate words are one-hot encoded once at construction time and
    embedded by a small MLP (``f_word``); policy logits are produced by
    matching the state embedding against every word embedding, so the
    action space corresponds exactly to the word list.
    """

    def __init__(self, s_dim, a_dim, word_list, words_width):
        """
        Args:
            s_dim: dimensionality of the flattened state vector.
            a_dim: number of actions; must equal ``len(word_list)``.
            word_list: candidate guess words (uppercase A-Z assumed).
            words_width: number of letters per word.

        Raises:
            ValueError: if ``a_dim`` does not match ``len(word_list)``.
        """
        super(Net, self).__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        n_emb = 32
        # Shared trunk: state -> 128 -> n_emb; the critic head (v3) maps the
        # n_emb embedding to a scalar state value.
        self.v1 = nn.Linear(s_dim, 128)
        self.v2 = nn.Linear(128, n_emb)
        self.v3 = nn.Linear(n_emb, 1)
        set_init([self.v1, self.v2])  # n_emb
        self.distribution = torch.distributions.Categorical
        # Validate explicitly instead of `assert`: asserts are stripped under
        # `python -O`, and the original message was unhelpful.
        if a_dim != len(word_list):
            raise ValueError(
                "a_dim (%d) must equal the number of words (%d)"
                % (a_dim, len(word_list)))
        word_width = 26 * words_width
        word_array = np.zeros((len(word_list), word_width))
        self.actor_head = nn.Linear(n_emb, n_emb)
        # One-hot encode each word: letter c at position j sets column
        # j*26 + (c - 'A').
        for i, word in enumerate(word_list):
            for j, c in enumerate(word):
                word_array[i, j * 26 + (ord(c) - ord('A'))] = 1
        self.words = torch.Tensor(word_array)
        # Embeds one-hot word vectors into the same n_emb space as the
        # state, so logits can be computed by a dot product.
        self.f_word = nn.Sequential(
            nn.Linear(word_width, 128),
            nn.Tanh(),
            nn.Linear(128, n_emb),
        )

    def forward(self, x):
        """Return ``(log-softmax policy logits, state values)`` for batch ``x``."""
        # Embed every candidate word; shape (n_emb, n_words) after transpose.
        # FIX: move to `x.device`, not `x.device.index` — the index is None on
        # CPU and a bare int on CUDA; a device object is always correct.
        fw = self.f_word(
            self.words.to(x.device),
        ).transpose(0, 1)
        v1 = torch.tanh(self.v1(x))
        values = self.v2(v1)
        # Policy: match each state embedding against every word embedding.
        logits = torch.log_softmax(
            torch.tensordot(self.actor_head(values), fw,
                            dims=((1,), (0,))),
            dim=-1)
        values = self.v3(values)
        return logits, values

    def choose_action(self, s):
        """Sample one action index from the current policy for state batch ``s``."""
        self.eval()
        logits, _ = self.forward(s)
        prob = F.softmax(logits, dim=1).data
        m = self.distribution(prob)
        # .cpu() so sampling also works when the network lives on CUDA.
        return m.sample().cpu().numpy()[0]

    def loss_func(self, s, a, v_t):
        """A3C loss: critic MSE on the TD error plus the policy-gradient term."""
        self.train()
        logits, values = self.forward(s)
        td = v_t - values
        c_loss = td.pow(2)
        probs = F.softmax(logits, dim=1)
        m = self.distribution(probs)
        # Advantage-weighted log-probability; td is detached so the critic
        # gradient does not flow through the actor term.
        exp_v = m.log_prob(a) * td.detach().squeeze()
        a_loss = -exp_v
        total_loss = (c_loss + a_loss).mean()
        return total_loss
class Worker(mp.Process):
    """One asynchronous A3C worker process.

    Rolls out episodes with a local copy of the network and periodically
    synchronises gradients with the shared global network.
    """

    def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width):
        super(Worker, self).__init__()
        self.name = 'w%02i' % name
        self.g_ep = global_ep
        self.g_ep_r = global_ep_r
        self.res_queue = res_queue
        self.gnet = gnet
        self.opt = opt
        self.word_list = words_list
        # Each worker owns a local network that is periodically synced with
        # the shared global one via push_and_pull.
        self.lnet = Net(N_S, N_A, words_list, word_width)
        self.env = env.unwrapped

    def run(self):
        step = 1
        while self.g_ep.value < MAX_EP:
            state = self.env.reset()
            buf_states, buf_actions, buf_rewards = [], [], []
            episode_reward = 0.
            while True:
                # Only the first worker renders, to keep output readable.
                if self.name == 'w00':
                    self.env.render()
                action = self.lnet.choose_action(v_wrap(state[None, :]))
                next_state, reward, done, _ = self.env.step(
                    self.env.encode_word(self.word_list[action]))
                episode_reward += reward
                buf_actions.append(action)
                buf_states.append(state)
                buf_rewards.append(reward)
                # Sync with the global net every UPDATE_GLOBAL_ITER steps,
                # and always at episode end.
                if step % UPDATE_GLOBAL_ITER == 0 or done:
                    push_and_pull(self.opt, self.lnet, self.gnet, done,
                                  next_state, buf_states, buf_actions,
                                  buf_rewards, GAMMA)
                    buf_states, buf_actions, buf_rewards = [], [], []
                    if done:
                        goal_word = self.env.decode_word(self.env.goal_word)
                        record(self.g_ep, self.g_ep_r, episode_reward,
                               self.res_queue, self.name, goal_word,
                               self.word_list[action])
                        break
                state = next_state
                step += 1
        # Sentinel telling the result consumer this worker has finished.
        self.res_queue.put(None)