Spaces:
Sleeping
Sleeping
File size: 2,576 Bytes
a777e34 23fd1ff c412087 a777e34 c412087 a777e34 c412087 a777e34 23fd1ff c10a05f c412087 c10a05f c412087 c10a05f a777e34 23fd1ff a777e34 c412087 a777e34 c10a05f c412087 c10a05f a777e34 fa34b1d c412087 a777e34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import os
import random
import numpy as np
import torch
import torch.multiprocessing as mp
from .net import Net
from .shared_adam import SharedAdam
from .worker import Worker
def _set_seed(seed: int = 100) -> None:
    """Seed every random number generator this module relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch with the
    same value, and exports ``PYTHONHASHSEED`` into the process environment.

    Args:
        seed: Value handed to each generator and written to ``PYTHONHASHSEED``.
    """
    # Export a fixed hash seed. Python reads this variable at interpreter
    # start-up, so it matters for processes launched after this call.
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    # On the cuDNN backend two more switches are needed for repeatable runs.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def train(
    env,
    max_ep,
    model_checkpoint_dir,
    gamma=0.0,
    seed=100,
    pretrained_model_path=None,
    save=False,
    min_reward=9.9,
    every_n_save=100,
):
    """Train a shared global network using one worker process per CPU core.

    A single ``Net`` is created, optionally initialised from a checkpoint,
    moved to shared memory and handed to ``mp.cpu_count()`` ``Worker``
    processes together with a ``SharedAdam`` optimizer and shared counters.
    The parent process collects whatever the workers put on the result queue
    until it receives ``None``, then joins all workers.

    Args:
        env: Environment exposing ``observation_space.shape``,
            ``action_space.n``, ``words`` and ``unwrapped.spec.id``. The same
            object is passed to every worker.
        max_ep: Forwarded to each worker (presumably the global episode
            budget; confirm against ``Worker``).
        model_checkpoint_dir: Directory for checkpoints; created if missing.
        gamma: Forwarded to each worker (presumably the reward discount
            factor; confirm against ``Worker``).
        seed: Seed passed to ``_set_seed`` before the network is built.
        pretrained_model_path: Optional path to a state dict loaded into the
            global network before training; also forwarded to each worker.
        save: When true, the final global network's state dict is written to
            ``model_checkpoint_dir``; also forwarded to each worker.
        min_reward: Forwarded to each worker unchanged.
        every_n_save: Forwarded to each worker unchanged.

    Returns:
        A tuple ``(global_ep, win_ep, gnet, res)``: the two shared integer
        ``mp.Value`` counters, the global network, and the list of values
        read from the result queue before the terminating ``None``.
    """
    os.environ["OMP_NUM_THREADS"] = "1"
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(model_checkpoint_dir, exist_ok=True)

    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    word_width = len(words_list[0])

    # Set global seeds for randoms
    _set_seed(seed)

    gnet = Net(n_s, n_a, words_list, word_width)  # global network
    if pretrained_model_path:
        # Load onto CPU so checkpoints saved on a GPU also work on CPU-only
        # hosts; load_state_dict copies values into the existing parameters.
        # NOTE(review): torch.load unpickles the file -- only load
        # checkpoints from trusted sources.
        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        gnet.load_state_dict(state_dict)
    gnet.share_memory()  # share the global parameters in multiprocessing

    opt = SharedAdam(
        gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)
    )  # global optimizer

    # Shared state visible to every worker process.
    global_ep = mp.Value("i", 0)
    global_ep_r = mp.Value("d", 0.0)
    res_queue = mp.Queue()
    win_ep = mp.Value("i", 0)

    # parallel training
    workers = [
        Worker(
            max_ep,
            gnet,
            opt,
            global_ep,
            global_ep_r,
            res_queue,
            i,
            env,
            n_s,
            n_a,
            words_list,
            word_width,
            win_ep,
            model_checkpoint_dir,
            gamma,
            pretrained_model_path,
            save,
            min_reward,
            every_n_save,
        )
        for i in range(mp.cpu_count())
    ]
    for worker in workers:
        worker.start()

    res = []  # record episode reward to plot
    # NOTE(review): collection stops at the first None; presumably a worker
    # posts None when training is finished -- confirm that no worker can
    # still be blocked flushing queued results when it is joined below.
    while True:
        r = res_queue.get()
        if r is None:
            break
        res.append(r)

    for worker in workers:
        worker.join()

    if save:
        torch.save(
            gnet.state_dict(),
            os.path.join(model_checkpoint_dir, f"model_{env.unwrapped.spec.id}.pth"),
        )
    return global_ep, win_ep, gnet, res
|