Module TeachMyAgent.students.openai_baselines.common.vec_env.subproc_vec_env
Expand source code
import multiprocessing as mp
import numpy as np
from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
def worker(remote, parent_remote, env_fn_wrappers):
    """
    Entry point of a SubprocVecEnv worker process: create the environments, then serve
    commands received on `remote` until asked to close.

    Arguments:

    remote: this worker's end of the pipe; `(cmd, data)` pairs are received on it and replies are sent back on it
    parent_remote: the parent's end of the pipe; closed right away because the worker never uses it
    env_fn_wrappers: object whose `.x` attribute is an iterable of zero-argument callables, each creating one environment

    Commands:

    'step': `data` holds one action per environment; replies with a list of (ob, reward, done, info) tuples
    'reset': replies with the list of observations returned by each env.reset()
    'render': replies with the list of env.render(mode='rgb_array') results
    'get_spaces_spec': replies with a CloudpickleWrapper around (observation_space, action_space, spec) of the first env
    'close': closes `remote` and leaves the command loop

    Raises:

    NotImplementedError: when an unknown command is received

    All environments are closed when the loop ends, whatever the reason.
    """
    def step_env(env, action):
        # Auto-reset: once an episode is done, the observation handed back is the
        # first observation of the next episode (reward/done/info still describe the last step).
        ob, reward, done, info = env.step(action)
        if done:
            ob = env.reset()
        return ob, reward, done, info

    parent_remote.close()
    envs = [env_fn_wrapper() for env_fn_wrapper in env_fn_wrappers.x]
    try:
        while True:
            try:
                cmd, data = remote.recv()
            except EOFError:
                # recv() raises EOFError when the other end of the pipe has been closed
                # (e.g. the parent process went away): nothing left to serve, shut down quietly.
                break
            if cmd == 'step':
                remote.send([step_env(env, action) for env, action in zip(envs, data)])
            elif cmd == 'reset':
                remote.send([env.reset() for env in envs])
            elif cmd == 'render':
                remote.send([env.render(mode='rgb_array') for env in envs])
            elif cmd == 'close':
                remote.close()
                break
            elif cmd == 'get_spaces_spec':
                remote.send(CloudpickleWrapper((envs[0].observation_space, envs[0].action_space, envs[0].spec)))
            else:
                raise NotImplementedError('SubprocVecEnv worker: unknown command {!r}'.format(cmd))
    except KeyboardInterrupt:
        print('SubprocVecEnv worker: got KeyboardInterrupt')
    finally:
        for env in envs:
            env.close()
class SubprocVecEnv(VecEnv):
    """
    VecEnv that runs multiple environments in parallel in subprocesses and communicates with them via pipes.
    Recommended to use when num_envs > 1 and step() can be a bottleneck.
    """
    def __init__(self, env_fns, spaces=None, context='spawn', in_series=1):
        """
        Arguments:

        env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable
        spaces: not read by this constructor (presumably kept for API compatibility - verify against callers)
        context: name of the multiprocessing start method, passed to mp.get_context
        in_series: number of environments to run in series in a single process
        (e.g. when len(env_fns) == 12 and in_series == 3, it will run 4 processes, each running 3 envs in series)
        """
        self.waiting = False
        self.closed = False
        # Safe defaults so that close()/__del__ can run even if construction fails
        # part-way (e.g. the divisibility assertion below or a failing process start).
        self.viewer = None
        self.remotes = ()
        self.work_remotes = ()
        self.ps = []
        self.in_series = in_series
        nenvs = len(env_fns)
        assert nenvs % in_series == 0, "Number of envs must be divisible by number of envs to run in series"
        self.nremotes = nenvs // in_series
        # One chunk of `in_series` env constructors per worker process
        env_fns = np.array_split(env_fns, self.nremotes)
        ctx = mp.get_context(context)
        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.nremotes)])
        self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            with clear_mpi_env_vars():
                p.start()
        # The parent only talks through `self.remotes`; drop its copies of the worker ends
        for remote in self.work_remotes:
            remote.close()

        # Spaces and spec are taken from the first environment of the first worker
        self.remotes[0].send(('get_spaces_spec', None))
        observation_space, action_space, self.spec = self.remotes[0].recv().x
        VecEnv.__init__(self, nenvs, observation_space, action_space)

    def step_async(self, actions):
        """Send each worker its share of `actions` without waiting for the results."""
        self._assert_not_closed()
        actions = np.array_split(actions, self.nremotes)
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        """
        Collect the results of the pending step from every worker.

        Returns (observations, rewards, dones, infos): observations stacked by _flatten_obs,
        rewards and dones stacked with np.stack, infos as a tuple with one entry per environment.
        """
        self._assert_not_closed()
        results = [remote.recv() for remote in self.remotes]
        results = _flatten_list(results)
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        """Reset every environment and return the stacked initial observations."""
        self._assert_not_closed()
        for remote in self.remotes:
            remote.send(('reset', None))
        obs = [remote.recv() for remote in self.remotes]
        obs = _flatten_list(obs)
        return _flatten_obs(obs)

    def close_extras(self):
        """
        Shut the worker processes down: drain any pending step results, ask every worker
        to close, wait for the processes to exit and release the parent's pipe ends.

        Shutdown is best-effort: a worker that already died must not prevent the
        remaining workers from being closed and joined.
        """
        self.closed = True
        if self.waiting:
            for remote in self.remotes:
                try:
                    remote.recv()
                except (EOFError, OSError):
                    # Worker is gone; there is no pending result to drain
                    pass
        for remote in self.remotes:
            try:
                remote.send(('close', None))
            except OSError:
                # Includes BrokenPipeError: the worker already exited
                pass
        for p in self.ps:
            # pid is None until the process has been started; joining an unstarted process is an error
            if p.pid is not None:
                p.join()
        for remote in self.remotes:
            remote.close()

    def get_images(self):
        """Ask every worker to render its environments and return the flat list of images."""
        self._assert_not_closed()
        for pipe in self.remotes:
            pipe.send(('render', None))
        imgs = [pipe.recv() for pipe in self.remotes]
        imgs = _flatten_list(imgs)
        return imgs

    def _assert_not_closed(self):
        assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"

    def __del__(self):
        # getattr guards against an instance whose __init__ never ran
        if not getattr(self, 'closed', True):
            self.close()
def _flatten_obs(obs):
assert isinstance(obs, (list, tuple))
assert len(obs) > 0
if isinstance(obs[0], dict):
keys = obs[0].keys()
return {k: np.stack([o[k] for o in obs]) for k in keys}
else:
return np.stack(obs)
def _flatten_list(l):
assert isinstance(l, (list, tuple))
assert len(l) > 0
assert all([len(l_) > 0 for l_ in l])
return [l__ for l_ in l for l__ in l_]
Functions
def worker(remote, parent_remote, env_fn_wrappers)
-
Expand source code
def worker(remote, parent_remote, env_fn_wrappers):
    """
    Entry point of a SubprocVecEnv worker process: create the environments, then serve
    commands received on `remote` until asked to close.

    Arguments:

    remote: this worker's end of the pipe; `(cmd, data)` pairs are received on it and replies are sent back on it
    parent_remote: the parent's end of the pipe; closed right away because the worker never uses it
    env_fn_wrappers: object whose `.x` attribute is an iterable of zero-argument callables, each creating one environment

    Commands:

    'step': `data` holds one action per environment; replies with a list of (ob, reward, done, info) tuples
    'reset': replies with the list of observations returned by each env.reset()
    'render': replies with the list of env.render(mode='rgb_array') results
    'get_spaces_spec': replies with a CloudpickleWrapper around (observation_space, action_space, spec) of the first env
    'close': closes `remote` and leaves the command loop

    Raises:

    NotImplementedError: when an unknown command is received

    All environments are closed when the loop ends, whatever the reason.
    """
    def step_env(env, action):
        # Auto-reset: once an episode is done, the observation handed back is the
        # first observation of the next episode (reward/done/info still describe the last step).
        ob, reward, done, info = env.step(action)
        if done:
            ob = env.reset()
        return ob, reward, done, info

    parent_remote.close()
    envs = [env_fn_wrapper() for env_fn_wrapper in env_fn_wrappers.x]
    try:
        while True:
            try:
                cmd, data = remote.recv()
            except EOFError:
                # recv() raises EOFError when the other end of the pipe has been closed
                # (e.g. the parent process went away): nothing left to serve, shut down quietly.
                break
            if cmd == 'step':
                remote.send([step_env(env, action) for env, action in zip(envs, data)])
            elif cmd == 'reset':
                remote.send([env.reset() for env in envs])
            elif cmd == 'render':
                remote.send([env.render(mode='rgb_array') for env in envs])
            elif cmd == 'close':
                remote.close()
                break
            elif cmd == 'get_spaces_spec':
                remote.send(CloudpickleWrapper((envs[0].observation_space, envs[0].action_space, envs[0].spec)))
            else:
                raise NotImplementedError('SubprocVecEnv worker: unknown command {!r}'.format(cmd))
    except KeyboardInterrupt:
        print('SubprocVecEnv worker: got KeyboardInterrupt')
    finally:
        for env in envs:
            env.close()
Classes
class SubprocVecEnv (env_fns, spaces=None, context='spawn', in_series=1)
-
VecEnv that runs multiple environments in parallel in subprocesses and communicates with them via pipes. Recommended to use when num_envs > 1 and step() can be a bottleneck.
Arguments:
env_fns: iterable of callables - functions that create environments to run in subprocesses. They need to be cloud-pickleable. in_series: number of environments to run in series in a single process (e.g. when len(env_fns) == 12 and in_series == 3, it will run 4 processes, each running 3 envs in series).
Expand source code
class SubprocVecEnv(VecEnv):
    """
    VecEnv that runs multiple environments in parallel in subprocesses and communicates with them via pipes.
    Recommended to use when num_envs > 1 and step() can be a bottleneck.
    """
    def __init__(self, env_fns, spaces=None, context='spawn', in_series=1):
        """
        Arguments:

        env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable
        spaces: not read by this constructor (presumably kept for API compatibility - verify against callers)
        context: name of the multiprocessing start method, passed to mp.get_context
        in_series: number of environments to run in series in a single process
        (e.g. when len(env_fns) == 12 and in_series == 3, it will run 4 processes, each running 3 envs in series)
        """
        self.waiting = False
        self.closed = False
        # Safe defaults so that close()/__del__ can run even if construction fails
        # part-way (e.g. the divisibility assertion below or a failing process start).
        self.viewer = None
        self.remotes = ()
        self.work_remotes = ()
        self.ps = []
        self.in_series = in_series
        nenvs = len(env_fns)
        assert nenvs % in_series == 0, "Number of envs must be divisible by number of envs to run in series"
        self.nremotes = nenvs // in_series
        # One chunk of `in_series` env constructors per worker process
        env_fns = np.array_split(env_fns, self.nremotes)
        ctx = mp.get_context(context)
        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.nremotes)])
        self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            with clear_mpi_env_vars():
                p.start()
        # The parent only talks through `self.remotes`; drop its copies of the worker ends
        for remote in self.work_remotes:
            remote.close()

        # Spaces and spec are taken from the first environment of the first worker
        self.remotes[0].send(('get_spaces_spec', None))
        observation_space, action_space, self.spec = self.remotes[0].recv().x
        VecEnv.__init__(self, nenvs, observation_space, action_space)

    def step_async(self, actions):
        """Send each worker its share of `actions` without waiting for the results."""
        self._assert_not_closed()
        actions = np.array_split(actions, self.nremotes)
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        """
        Collect the results of the pending step from every worker.

        Returns (observations, rewards, dones, infos): observations stacked by _flatten_obs,
        rewards and dones stacked with np.stack, infos as a tuple with one entry per environment.
        """
        self._assert_not_closed()
        results = [remote.recv() for remote in self.remotes]
        results = _flatten_list(results)
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        """Reset every environment and return the stacked initial observations."""
        self._assert_not_closed()
        for remote in self.remotes:
            remote.send(('reset', None))
        obs = [remote.recv() for remote in self.remotes]
        obs = _flatten_list(obs)
        return _flatten_obs(obs)

    def close_extras(self):
        """
        Shut the worker processes down: drain any pending step results, ask every worker
        to close, wait for the processes to exit and release the parent's pipe ends.

        Shutdown is best-effort: a worker that already died must not prevent the
        remaining workers from being closed and joined.
        """
        self.closed = True
        if self.waiting:
            for remote in self.remotes:
                try:
                    remote.recv()
                except (EOFError, OSError):
                    # Worker is gone; there is no pending result to drain
                    pass
        for remote in self.remotes:
            try:
                remote.send(('close', None))
            except OSError:
                # Includes BrokenPipeError: the worker already exited
                pass
        for p in self.ps:
            # pid is None until the process has been started; joining an unstarted process is an error
            if p.pid is not None:
                p.join()
        for remote in self.remotes:
            remote.close()

    def get_images(self):
        """Ask every worker to render its environments and return the flat list of images."""
        self._assert_not_closed()
        for pipe in self.remotes:
            pipe.send(('render', None))
        imgs = [pipe.recv() for pipe in self.remotes]
        imgs = _flatten_list(imgs)
        return imgs

    def _assert_not_closed(self):
        assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"

    def __del__(self):
        # getattr guards against an instance whose __init__ never ran
        if not getattr(self, 'closed', True):
            self.close()
Ancestors
- VecEnv
- abc.ABC
Inherited members