"""
Reinforcement Learning (A3C) using Pytroch + multiprocessing.
The most simple implementation for continuous action.

View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import torch.multiprocessing as mp
from .utils import v_wrap, set_init, push_and_pull, record
import numpy as np

GAMMA = 0.9

class Net(nn.Module):
    def __init__(self, s_dim, a_dim, word_list, words_width):
        super(Net, self).__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        n_emb = 32
        # self.pi1 = nn.Linear(s_dim, 128)
        # self.pi2 = nn.Linear(128, a_dim)
        self.v1 = nn.Linear(s_dim, 256)
        self.v2 = nn.Linear(256, n_emb)
        self.v3 = nn.Linear(n_emb, 1)
        set_init([self.v1, self.v2])  # initialize weights of the value-stream layers (see utils.set_init)
        self.distribution = torch.distributions.Categorical
        word_width = 26 * words_width
        word_array = np.zeros((len(word_list), word_width))
        self.actor_head = nn.Linear(n_emb, n_emb)
        # One-hot encode every word: 26 slots per character position,
        # assuming the word list is uppercase A-Z.
        for i, word in enumerate(word_list):
            for j, c in enumerate(word):
                word_array[i, j * 26 + (ord(c) - ord('A'))] = 1
        self.words = torch.Tensor(word_array)
        self.f_word = nn.Sequential(
            nn.Linear(word_width, 64),
            nn.Tanh(),
            nn.Linear(64, n_emb),
        )

    def forward(self, x):
        # pi1 = torch.tanh(self.pi1(x))
        # Embed every candidate word once per forward pass.
        fw = self.f_word(
            self.words.to(x.device),
        ).transpose(0, 1)
        # logits = self.pi2(pi1)
        v1 = torch.tanh(self.v1(x))
        values = self.v2(v1)
        # Actor: score the state embedding against each word embedding.
        logits = torch.log_softmax(
            torch.tensordot(self.actor_head(values), fw,
                            dims=((1,), (0,))),
            dim=-1)
        # Critic: scalar state value.
        values = self.v3(values)
        return logits, values

    def choose_action(self, s):
        self.eval()
        logits, _ = self.forward(s)
        prob = F.softmax(logits, dim=1).data
        m = self.distribution(prob)
        # Sample a word index from the categorical policy.
        return m.sample().numpy()[0]

    def loss_func(self, s, a, v_t):
        self.train()
        logits, values = self.forward(s)
        # Advantage: bootstrapped return minus the critic's value estimate.
        td = v_t - values
        c_loss = td.pow(2)

        probs = F.softmax(logits, dim=1)
        m = self.distribution(probs)
        # Policy-gradient term weighted by the detached advantage.
        exp_v = m.log_prob(a) * td.detach().squeeze()
        a_loss = -exp_v
        total_loss = (c_loss + a_loss).mean()
        return total_loss


class Worker(mp.Process):
    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
        super(Worker, self).__init__()
        self.max_ep = max_ep
        self.name = 'w%02i' % name
        self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list
        self.lnet = Net(N_S, N_A, words_list, word_width)           # local network
        self.env = env.unwrapped

    def run(self):
        while self.g_ep.value < self.max_ep:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.
            while True:
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)

                if done:  # update global and assign to local net
                    # sync
                    push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
                    goal_word = self.word_list[self.env.goal_word]
                    record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
                    buffer_s, buffer_a, buffer_r = [], [], []
                    break
                s = s_
        self.res_queue.put(None)
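

# ---------------------------------------------------------------------------
# Minimal launch sketch (NOT part of the original module): it shows how the
# global Net, a shared optimizer, and several Worker processes are typically
# wired together for A3C. `SharedAdam`, `make_env`, `WORD_LIST`, and
# `WORD_WIDTH` are assumed names for companion modules in this repository and
# a gym-style environment; the real entry point may differ. Because of the
# relative imports, run it as a module (python -m <package>.<this_module>).
if __name__ == "__main__":
    from .shared_adam import SharedAdam        # assumed shared-memory Adam optimizer
    from .wordle_env import make_env, WORD_LIST, WORD_WIDTH  # hypothetical env factory

    env = make_env()
    N_S, N_A = env.observation_space.shape[0], env.action_space.n

    gnet = Net(N_S, N_A, WORD_LIST, WORD_WIDTH)   # global network
    gnet.share_memory()                           # share parameters across worker processes
    opt = SharedAdam(gnet.parameters(), lr=1e-4)  # optimizer with shared state

    global_ep, global_ep_r = mp.Value('i', 0), mp.Value('d', 0.)
    winning_ep = mp.Value('i', 0)
    res_queue = mp.Queue()

    workers = [
        Worker(3000, gnet, opt, global_ep, global_ep_r, res_queue, i,
               make_env(), N_S, N_A, WORD_LIST, WORD_WIDTH, winning_ep)
        for i in range(mp.cpu_count())
    ]
    for w in workers:
        w.start()

    # Collect per-episode rewards until a worker signals completion with None.
    rewards = []
    while True:
        r = res_queue.get()
        if r is None:
            break
        rewards.append(r)
    for w in workers:
        w.join()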