santit96 committed on
Commit abff1ef · 1 Parent(s): 44db2f9

The one-hot encoding of the words is now embedded inside the net, so it is no longer necessary to one-hot encode in the env

Files changed (3)
  1. a3c/discrete_A3C.py +40 -15
  2. a3c/utils.py +9 -6
  3. main.py +5 -3
a3c/discrete_A3C.py CHANGED
@@ -11,29 +11,52 @@ import torch.nn.functional as F
 import gym
 import torch.multiprocessing as mp
 from .utils import v_wrap, set_init, push_and_pull, record
-
+import numpy as np
 
 UPDATE_GLOBAL_ITER = 5
 GAMMA = 0.9
-MAX_EP = 3000
+MAX_EP = 100000
 
 class Net(nn.Module):
-    def __init__(self, s_dim, a_dim):
+    def __init__(self, s_dim, a_dim, word_list, words_width):
         super(Net, self).__init__()
         self.s_dim = s_dim
         self.a_dim = a_dim
-        self.pi1 = nn.Linear(s_dim, 128)
-        self.pi2 = nn.Linear(128, a_dim)
+        n_emb = 32
+        # self.pi1 = nn.Linear(s_dim, 128)
+        # self.pi2 = nn.Linear(128, a_dim)
         self.v1 = nn.Linear(s_dim, 128)
-        self.v2 = nn.Linear(128, 1)
-        set_init([self.pi1, self.pi2, self.v1, self.v2])
+        self.v2 = nn.Linear(128, n_emb)
+        self.v3 = nn.Linear(n_emb, 1)
+        set_init([self.v1, self.v2])
         self.distribution = torch.distributions.Categorical
+        assert a_dim == len(word_list), "a_dim must equal the number of words"
+        word_width = 26 * words_width
+        word_array = np.zeros((len(word_list), word_width))
+        self.actor_head = nn.Linear(n_emb, n_emb)
+        for i, word in enumerate(word_list):
+            for j, c in enumerate(word):
+                word_array[i, j*26 + (ord(c) - ord('A'))] = 1
+        self.words = torch.Tensor(word_array)
+        self.f_word = nn.Sequential(
+            nn.Linear(word_width, 128),
+            nn.Tanh(),
+            nn.Linear(128, n_emb),
+        )
 
     def forward(self, x):
-        pi1 = torch.tanh(self.pi1(x))
-        logits = self.pi2(pi1)
+        # pi1 = torch.tanh(self.pi1(x))
+        fw = self.f_word(
+            self.words.to(x.device.index),
+        ).transpose(0, 1)
+        # logits = self.pi2(pi1)
         v1 = torch.tanh(self.v1(x))
         values = self.v2(v1)
+        logits = torch.log_softmax(
+            torch.tensordot(self.actor_head(values), fw,
+                            dims=((1,), (0,))),
+            dim=-1)
+        values = self.v3(values)
         return logits, values
 
     def choose_action(self, s):
@@ -58,13 +81,14 @@ class Net(nn.Module):
 
 
 class Worker(mp.Process):
-    def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name, N_S, N_A):
+    def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width):
         super(Worker, self).__init__()
         self.name = 'w%02i' % name
         self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
         self.gnet, self.opt = gnet, opt
-        self.lnet = Net(N_S, N_A)           # local network
-        self.env = gym.make('WordleEnv100OneAction-v0').unwrapped
+        self.word_list = words_list
+        self.lnet = Net(N_S, N_A, words_list, word_width)           # local network
+        self.env = env.unwrapped
 
     def run(self):
         total_step = 1
@@ -76,8 +100,7 @@ class Worker(mp.Process):
             if self.name == 'w00':
                 self.env.render()
             a = self.lnet.choose_action(v_wrap(s[None, :]))
-            s_, r, done, _ = self.env.step(a)
-            if done: r = -1
+            s_, r, done, _ = self.env.step(self.env.encode_word(self.word_list[a]))
             ep_r += r
             buffer_a.append(a)
             buffer_s.append(s)
@@ -89,8 +112,10 @@ class Worker(mp.Process):
                 buffer_s, buffer_a, buffer_r = [], [], []
 
                 if done:  # done and print information
-                    record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name)
+                    goal_word = self.env.decode_word(self.env.goal_word)
+                    record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a])
                     break
+
             s = s_
             total_step += 1
         self.res_queue.put(None)
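For reference, the shape arithmetic behind the new actor head can be checked in isolation: Net now builds a fixed one-hot matrix of shape (n_words, 26 * word_length), embeds it with f_word into (n_words, n_emb), and scores the state embedding against every word embedding with a dot product to get per-word logits. A minimal, self-contained sketch of that computation (the state and batch sizes below are illustrative, not taken from the repo):

import torch
import torch.nn as nn

n_words, word_len, n_emb = 100, 5, 32   # a 100-word vocabulary of 5-letter words
state_dim, batch = 417, 4               # illustrative sizes only

# Fixed one-hot encoding of the vocabulary, one row per word,
# filled the same way as in Net.__init__.
words = torch.zeros(n_words, 26 * word_len)

f_word = nn.Sequential(nn.Linear(26 * word_len, 128), nn.Tanh(), nn.Linear(128, n_emb))
v1, v2, v3 = nn.Linear(state_dim, 128), nn.Linear(128, n_emb), nn.Linear(n_emb, 1)
actor_head = nn.Linear(n_emb, n_emb)

x = torch.randn(batch, state_dim)       # a batch of observations
fw = f_word(words).transpose(0, 1)      # (n_emb, n_words) word embeddings
h = v2(torch.tanh(v1(x)))               # (batch, n_emb) state embedding
logits = torch.log_softmax(             # (batch, n_words) per-word scores
    torch.tensordot(actor_head(h), fw, dims=((1,), (0,))), dim=-1)
values = v3(h)                          # (batch, 1) state value
print(logits.shape, values.shape)       # torch.Size([4, 100]) torch.Size([4, 1])

Because the policy scores word embeddings instead of owning one output unit per word, the vocabulary lives inside the net, which is what lets the env stop one-hot encoding actions.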
a3c/utils.py CHANGED
@@ -47,7 +47,7 @@ def push_and_pull(opt, lnet, gnet, done, s_, bs, ba, br, gamma):
     lnet.load_state_dict(gnet.state_dict())
 
 
-def record(global_ep, global_ep_r, ep_r, res_queue, name):
+def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action):
     with global_ep.get_lock():
         global_ep.value += 1
     with global_ep_r.get_lock():
@@ -56,8 +56,11 @@ def record(global_ep, global_ep_r, ep_r, res_queue, name):
     else:
         global_ep_r.value = global_ep_r.value * 0.99 + ep_r * 0.01
     res_queue.put(global_ep_r.value)
-    print(
-        name,
-        "Ep:", global_ep.value,
-        "| Ep_r: %.0f" % global_ep_r.value,
-    )
+    if goal_word == action:
+        print(
+            name,
+            "Ep:", global_ep.value,
+            "| Ep_r: %.0f" % global_ep_r.value,
+            "| Goal:", goal_word,
+            "| Action:", action
+        )
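The gated record now prints only when the episode's final guess matches the goal word. A small sanity check of the new signature, assuming the shared counters are created the same way main.py creates them:

import torch.multiprocessing as mp
from a3c.utils import record

global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()

# A miss is counted and pushed to the result queue but stays silent...
record(global_ep, global_ep_r, 1.0, res_queue, 'w00', 'HELLO', 'WORLD')
# ...while a correct guess also prints the episode summary.
record(global_ep, global_ep_r, 10.0, res_queue, 'w00', 'HELLO', 'HELLO')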
main.py CHANGED
@@ -10,18 +10,20 @@ from wordle_env.wordle import WordleEnvBase
 
 os.environ["OMP_NUM_THREADS"] = "1"
 
-env = gym.make('WordleEnv100OneAction-v0')
+env = gym.make('WordleEnv100FullAction-v0')
 N_S = env.observation_space.shape[0]
 N_A = env.action_space.shape[0]
 
 if __name__ == "__main__":
-    gnet = Net(N_S, N_A)        # global network
+    words_list = env.words
+    word_width = len(env.words[0])
+    gnet = Net(N_S, N_A, words_list, word_width)        # global network
     gnet.share_memory()         # share the global parameters in multiprocessing
     opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))      # global optimizer
     global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()
 
     # parallel training
-    workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i, N_S=N_S, N_A=N_A) for i in range(mp.cpu_count())]
+    workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i, env, N_S=N_S, N_A=N_A, words_list=words_list, word_width=word_width) for i in range(mp.cpu_count())]
     [w.start() for w in workers]
     res = []                    # record episode reward to plot
     while True:
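Since the worker now calls self.env.step(self.env.encode_word(...)) and self.env.decode_word(self.env.goal_word), the env is expected to expose word codecs rather than one-hot vectors. Their implementation is not part of this commit; a hypothetical sketch of what such helpers could look like:

import numpy as np

def encode_word(word):
    # Hypothetical: map 'HELLO' -> integer codes [7, 4, 11, 11, 14] (A=0 .. Z=25).
    return np.array([ord(c) - ord('A') for c in word])

def decode_word(encoded):
    # Hypothetical inverse: [7, 4, 11, 11, 14] -> 'HELLO'.
    return ''.join(chr(int(c) + ord('A')) for c in encoded)

assert decode_word(encode_word('HELLO')) == 'HELLO'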