"""
Reinforcement Learning (A3C) using Pytroch + multiprocessing.
The most simple implementation for continuous action.

View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import torch.multiprocessing as mp
from .utils import v_wrap, set_init, push_and_pull, record
import numpy as np

UPDATE_GLOBAL_ITER = 5   # sync local net with the global net every 5 steps
GAMMA = 0.9              # reward discount factor
MAX_EP = 100000          # total episodes across all workers

class Net(nn.Module):
    def __init__(self, s_dim, a_dim, word_list, words_width):
        super().__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        n_emb = 32
        # shared trunk: state -> hidden -> embedding; v3 maps the embedding to a scalar value
        self.v1 = nn.Linear(s_dim, 128)
        self.v2 = nn.Linear(128, n_emb)
        self.v3 = nn.Linear(n_emb, 1)
        self.actor_head = nn.Linear(n_emb, n_emb)
        set_init([self.v1, self.v2])
        self.distribution = torch.distributions.Categorical
        assert a_dim == len(word_list), "a_dim must equal the number of candidate words"
        # one-hot encode every candidate word: 26 letter slots per character position
        word_width = 26 * words_width
        word_array = np.zeros((len(word_list), word_width))
        for i, word in enumerate(word_list):
            for j, c in enumerate(word):
                word_array[i, j * 26 + (ord(c) - ord('A'))] = 1
        self.words = torch.Tensor(word_array)
        # embeds each one-hot word into the same n_emb space as the actor head output
        self.f_word = nn.Sequential(
            nn.Linear(word_width, 128),
            nn.Tanh(),
            nn.Linear(128, n_emb),
        )

    def forward(self, x):
        # embed all candidate words: (num_words, n_emb) -> (n_emb, num_words)
        fw = self.f_word(self.words.to(x.device)).transpose(0, 1)
        v1 = torch.tanh(self.v1(x))
        emb = self.v2(v1)
        # actor: score each candidate word against the state embedding
        logits = torch.log_softmax(
            torch.tensordot(self.actor_head(emb), fw, dims=((1,), (0,))),
            dim=-1)
        # critic: scalar state value from the shared embedding
        values = self.v3(emb)
        return logits, values

    def choose_action(self, s):
        self.eval()
        with torch.no_grad():
            logits, _ = self.forward(s)
        prob = F.softmax(logits, dim=1)
        m = self.distribution(prob)
        return m.sample().numpy()[0]

    def loss_func(self, s, a, v_t):
        self.train()
        logits, values = self.forward(s)
        # critic loss: squared TD error between the target return and the value estimate
        td = v_t - values
        c_loss = td.pow(2)

        # actor loss: policy gradient weighted by the detached TD advantage
        probs = F.softmax(logits, dim=1)
        m = self.distribution(probs)
        exp_v = m.log_prob(a) * td.detach().squeeze()
        a_loss = -exp_v
        total_loss = (c_loss + a_loss).mean()
        return total_loss


class Worker(mp.Process):
    def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width):
        super().__init__()
        self.name = 'w%02i' % name
        self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list
        self.lnet = Net(N_S, N_A, words_list, word_width)  # local copy of the global network
        self.env = env.unwrapped

    def run(self):
        total_step = 1
        while self.g_ep.value < MAX_EP:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.
            while True:
                if self.name == 'w00':  # only the first worker renders
                    self.env.render()
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(self.env.encode_word(self.word_list[a]))
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)

                if total_step % UPDATE_GLOBAL_ITER == 0 or done:  # update global and assign to local net
                    # sync: push local gradients to the global net, then pull its weights back
                    push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
                    buffer_s, buffer_a, buffer_r = [], [], []

                    if done:  # episode finished: log the result
                        goal_word = self.env.decode_word(self.env.goal_word)
                        record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a])
                        break

                s = s_
                total_step += 1
        self.res_queue.put(None)  # signal the main process that this worker is done
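

if __name__ == "__main__":
    # Usage sketch, not part of the original file: how these classes are typically
    # wired together in the A3C pattern (shared global Net, shared optimizer, one
    # Worker per CPU). The env id "WordleEnv-v0", its `words` / `word_width`
    # attributes, and the plain Adam optimizer are assumptions; the repo's real
    # entry point (often using a shared-state optimizer such as SharedAdam) may differ.
    env = gym.make("WordleEnv-v0")               # hypothetical Wordle-style env
    word_list = env.unwrapped.words              # assumed env attribute
    word_width = env.unwrapped.word_width        # assumed env attribute
    N_S = env.observation_space.shape[0]
    N_A = len(word_list)

    gnet = Net(N_S, N_A, word_list, word_width)  # global network
    gnet.share_memory()                          # share weights across worker processes
    opt = torch.optim.Adam(gnet.parameters(), lr=1e-4)

    global_ep = mp.Value('i', 0)                 # shared episode counter
    global_ep_r = mp.Value('d', 0.)              # shared running reward
    res_queue = mp.Queue()                       # workers report episode rewards here

    workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i,
                      env, N_S, N_A, word_list, word_width)
               for i in range(mp.cpu_count())]
    for w in workers:
        w.start()

    res = []  # collect episode rewards until a worker sends the None sentinel
    while True:
        r = res_queue.get()
        if r is None:
            break
        res.append(r)
    for w in workers:
        w.join()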