# wordle-solver / a3c / train.py
# Last commit c412087 by santit96: "Fix code style with black and isort"
# (File size at that commit: 2.58 kB.)
import os
import queue
import random

import numpy as np
import torch
import torch.multiprocessing as mp

from .net import Net
from .shared_adam import SharedAdam
from .worker import Worker
def _set_seed(seed: int = 100) -> None:
    """Seed every random number generator used by training for reproducibility.

    Seeds NumPy's global RNG, Python's ``random`` module and PyTorch. When
    CUDA is available it also seeds the current CUDA device and switches
    cuDNN to deterministic, non-autotuned kernels. Finally it exports
    ``PYTHONHASHSEED``.

    Args:
        seed: Value handed to every seeding call.
    """
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # cuDNN may choose kernels non-deterministically, so pin deterministic
        # algorithms and turn off the benchmark autotuner.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # NOTE: Python reads PYTHONHASHSEED only at interpreter start-up. Setting it
    # here does not change str/bytes hashing in this process. It only affects
    # Python processes started later that inherit this environment.
    os.environ["PYTHONHASHSEED"] = str(seed)
def _collect_results(res_queue, workers, poll_interval=1.0):
    """Drain ``res_queue`` until all ``workers`` are done and return their items.

    ``None`` items are end-of-work sentinels and are not recorded. Every other
    item is appended to the returned list in arrival order. Collection ends
    when one sentinel per worker has arrived. It also ends when a read finds
    the queue empty and every worker had already exited before that read,
    which covers workers that stop without sending a sentinel.

    The queue must be fully drained before the caller joins the workers.
    Per the multiprocessing docs, joining a process whose queued items have
    not all been consumed can deadlock.
    """
    results = []
    sentinels_seen = 0
    while sentinels_seen < len(workers):
        # Sample liveness *before* reading. An exited process has already
        # flushed everything it put on the queue. So if all workers were gone
        # before this read, an empty read means nothing is left to collect.
        all_exited = not any(w.is_alive() for w in workers)
        try:
            item = res_queue.get(timeout=poll_interval)
        except queue.Empty:
            if all_exited:
                break
            continue
        if item is None:
            sentinels_seen += 1
        else:
            results.append(item)
    return results


def train(
    env,
    max_ep,
    model_checkpoint_dir,
    gamma=0.0,
    seed=100,
    pretrained_model_path=None,
    save=False,
    min_reward=9.9,
    every_n_save=100,
    n_workers=None,
):
    """Train a shared global ``Net`` on ``env`` with parallel worker processes.

    All ``Worker`` processes share the global network ``gnet`` and the
    ``SharedAdam`` optimizer. The main process collects everything the
    workers push onto the result queue until they are done, then joins them.
    If requested, it finally saves the global network's weights.

    Args:
        env: Environment exposing ``observation_space.shape``,
            ``action_space.n``, a non-empty ``words`` list and
            ``unwrapped.spec.id``.
        max_ep: Forwarded to every ``Worker``. Presumably the global episode
            budget; confirm in ``worker.py``.
        model_checkpoint_dir: Checkpoint directory, created if missing.
        gamma: Forwarded to every ``Worker`` (presumably the discount factor).
        seed: Seed passed to ``_set_seed`` before the network is built.
        pretrained_model_path: Optional ``state_dict`` file used to
            initialise ``gnet``.
        save: Forwarded to every ``Worker``. If true, ``gnet.state_dict()``
            is also saved to ``model_checkpoint_dir/model_<env id>.pth``
            after training.
        min_reward: Forwarded to every ``Worker``.
        every_n_save: Forwarded to every ``Worker``.
        n_workers: Number of worker processes. ``None`` (default) uses
            ``mp.cpu_count()``.

    Returns:
        A ``(global_ep, win_ep, gnet, res)`` tuple:

        - ``global_ep`` and ``win_ep``: the shared integer ``mp.Value``
          counters (read them with ``.value``).
        - ``gnet``: the global network.
        - ``res``: the list of non-``None`` items the workers sent
          (presumably episode rewards, for plotting).
    """
    # One OpenMP thread per process avoids oversubscribing the CPU across
    # workers. NOTE(review): torch is already imported at this point, so this
    # presumably only takes effect in freshly started child processes.
    os.environ["OMP_NUM_THREADS"] = "1"
    os.makedirs(model_checkpoint_dir, exist_ok=True)

    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    word_width = len(words_list[0])

    # Seed global RNGs before building the network so its initial weights
    # are reproducible.
    _set_seed(seed)

    gnet = Net(n_s, n_a, words_list, word_width)  # global network
    if pretrained_model_path:
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # CPU-only hosts. load_state_dict then copies the values into gnet's
        # parameters wherever those live.
        # torch.load unpickles the file: only load checkpoints you trust.
        gnet.load_state_dict(torch.load(pretrained_model_path, map_location="cpu"))
    gnet.share_memory()  # share the global parameters across processes
    opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))

    # Cross-process state shared with every worker.
    global_ep = mp.Value("i", 0)
    global_ep_r = mp.Value("d", 0.0)
    res_queue = mp.Queue()
    win_ep = mp.Value("i", 0)

    if n_workers is None:
        n_workers = mp.cpu_count()

    workers = [
        Worker(
            max_ep,
            gnet,
            opt,
            global_ep,
            global_ep_r,
            res_queue,
            i,
            env,
            n_s,
            n_a,
            words_list,
            word_width,
            win_ep,
            model_checkpoint_dir,
            gamma,
            pretrained_model_path,
            save,
            min_reward,
            every_n_save,
        )
        for i in range(n_workers)
    ]
    for w in workers:
        w.start()

    # Drain the queue completely before joining; see _collect_results.
    res = _collect_results(res_queue, workers)

    for w in workers:
        w.join()

    if save:
        torch.save(
            gnet.state_dict(),
            os.path.join(model_checkpoint_dir, f"model_{env.unwrapped.spec.id}.pth"),
        )
    return global_ep, win_ep, gnet, res