"""
Worker class implementation of the a3c discrete algorithm
"""
import os
import torch
import numpy as np
import torch.multiprocessing as mp
from torch import nn
from .net import Net
from .utils import v_wrap


GAMMA = 0.65  # discount factor used for the n-step return targets


class Worker(mp.Process):
    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir, pretrained_model_path=None):
        super(Worker, self).__init__()
        self.max_ep = max_ep
        self.name = 'w%02i' % name
        # shared counters and result queue used to coordinate with the other workers
        self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
        # global (shared) network and shared optimizer
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list
        # local network, periodically synchronised with the global one
        self.lnet = Net(N_S, N_A, words_list, word_width)
        if pretrained_model_path:
            self.lnet.load_state_dict(torch.load(pretrained_model_path))
        self.env = env.unwrapped
        self.model_checkpoint_dir = model_checkpoint_dir

    def run(self):
        # roll out episodes until the shared episode counter reaches max_ep
        while self.g_ep.value < self.max_ep:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.
            while True:
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)

                if done:  # update global and assign to local net
                    # sync
                    self.push_and_pull(done, s_, buffer_s,
                                       buffer_a, buffer_r, GAMMA)
                    goal_word = self.word_list[self.env.goal_word]
                    self.record(ep_r, goal_word,
                                self.word_list[a], len(buffer_a))
                    self.save_model()
                    buffer_s, buffer_a, buffer_r = [], [], []
                    break
                s = s_
        self.res_queue.put(None)

    def push_and_pull(self, done, s_, bs, ba, br, gamma):
        if done:
            v_s_ = 0.               # terminal
        else:
            v_s_ = self.lnet.forward(v_wrap(
                s_[None, :]))[-1].data.numpy()[0, 0]

        # n-step discounted return targets: G_t = r_t + gamma * G_{t+1},
        # bootstrapped from v_s_ (0 at a terminal state)
        buffer_v_target = []
        for r in br[::-1]:    # iterate rewards in reverse order
            v_s_ = r + gamma * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()

        # discrete actions arrive as integer scalars, continuous ones as arrays
        if ba[0].dtype == np.int64:
            ba_t = v_wrap(np.array(ba), dtype=np.int64)
        else:
            ba_t = v_wrap(np.vstack(ba))

        loss = self.lnet.loss_func(
            v_wrap(np.vstack(bs)),
            ba_t,
            v_wrap(np.array(buffer_v_target)[:, None]))

        # compute gradients on the local network and copy them onto the
        # corresponding global parameters so the shared optimizer can apply them
        self.opt.zero_grad()
        loss.backward()
        for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
            gp._grad = lp.grad
        self.opt.step()

        # pull global parameters
        self.lnet.load_state_dict(self.gnet.state_dict())

    def save_model(self):
        # checkpoint the global network once the running reward is high enough,
        # and only every 100th episode to limit disk writes
        if self.g_ep_r.value >= 9.9 and self.g_ep.value % 100 == 0:
            torch.save(self.gnet.state_dict(), os.path.join(
                self.model_checkpoint_dir, f'model_{self.g_ep.value}.pth'))

    def record(self, ep_r, goal_word, action, action_number):
        with self.g_ep.get_lock():
            self.g_ep.value += 1
        with self.g_ep_r.get_lock():
            # exponential moving average of the episode reward
            if self.g_ep_r.value == 0.:
                self.g_ep_r.value = ep_r
            else:
                self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
        self.res_queue.put(self.g_ep_r.value)
        if goal_word == action:
            # the final guess matched the goal word, i.e. the episode was won
            with self.winning_ep.get_lock():
                self.winning_ep.value += 1
            if self.g_ep.value % 100 == 0:
                print(
                    self.name,
                    "Ep:", self.g_ep.value,
                    "| Ep_r: %.0f" % self.g_ep_r.value,
                    "| Goal:", goal_word,
                    "| Action:", action,
                    "| Actions:", action_number
                )
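

# -----------------------------------------------------------------------------
# Example usage (hypothetical sketch, not part of this module): a training
# script would typically build a shared global Net, a shared optimizer, and the
# shared counters/queue, then spawn several Worker processes. The SharedAdam
# optimizer, the environment object and the concrete argument values below are
# assumptions for illustration only; only the Worker constructor signature is
# taken from this file.
#
#   gnet = Net(N_S, N_A, words_list, word_width)
#   gnet.share_memory()                      # share weights across processes
#   opt = SharedAdam(gnet.parameters(), lr=1e-4)
#   global_ep = mp.Value('i', 0)             # shared episode counter
#   global_ep_r = mp.Value('d', 0.)          # shared moving-average reward
#   winning_ep = mp.Value('i', 0)            # shared count of solved episodes
#   res_queue = mp.Queue()                   # reward history for plotting
#
#   workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i,
#                     env, N_S, N_A, words_list, word_width, winning_ep,
#                     model_checkpoint_dir)
#              for i in range(mp.cpu_count())]
#   [w.start() for w in workers]
#   [w.join() for w in workers]
# -----------------------------------------------------------------------------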