Module TeachMyAgent.students.openai_baselines.common.vec_env.vec_frame_stack
Expand source code
from .vec_env import VecEnvWrapper
import numpy as np
from gym import spaces
class VecFrameStack(VecEnvWrapper):
    """Stack the most recent ``nstack`` observations of every sub-environment.

    Frames are concatenated along the last axis of the observation, oldest
    first and newest last, so a wrapped observation whose last dimension is
    ``C`` becomes a stacked observation whose last dimension is
    ``C * nstack``.
    """

    def __init__(self, venv, nstack):
        """Wrap ``venv`` so that it emits stacks of ``nstack`` frames.

        Args:
            venv: Vectorized environment to wrap. This constructor reads its
                ``observation_space`` (``low``, ``high``, ``dtype``) and its
                ``num_envs``.
            nstack: Number of consecutive frames kept per environment.
        """
        self.venv = venv
        self.nstack = nstack
        wos = venv.observation_space  # wrapped ob space
        # The buffer stores whole frames back to back on the last axis
        # ([frame_0 | frame_1 | ...]), so the bounds must be tiled per frame.
        # np.repeat would instead duplicate each channel in place, which only
        # matches the buffer layout when every channel has the same bound.
        low = np.tile(wos.low, self.nstack)
        high = np.tile(wos.high, self.nstack)
        # One zero-initialised stack per sub-environment.
        self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
        observation_space = spaces.Box(low=low, high=high, dtype=wos.dtype)
        VecEnvWrapper.__init__(self, venv, observation_space=observation_space)

    def step_wait(self):
        """Collect the wrapped step result and push the new frame on the stack.

        Returns:
            A tuple ``(stackedobs, rews, news, infos)``. ``rews``, ``news`` and
            ``infos`` are passed through unchanged from the wrapped
            environment; ``stackedobs`` is this wrapper's internal buffer.
        """
        obs, rews, news, infos = self.venv.step_wait()
        # A frame occupies obs.shape[-1] slots on the last axis, so the stack
        # must shift by that many slots. Shifting by a single slot would
        # misalign frames whenever the observation has more than one channel.
        frame_width = obs.shape[-1]
        self.stackedobs = np.roll(self.stackedobs, shift=-frame_width, axis=-1)
        for i, new in enumerate(news):
            if new:
                # Clear the history of every environment whose flag is set so
                # that older frames do not remain in its stack.
                self.stackedobs[i] = 0
        # The newest frame always lives in the trailing slots.
        self.stackedobs[..., -frame_width:] = obs
        return self.stackedobs, rews, news, infos

    def reset(self):
        """Reset the wrapped environments and restart every stack.

        Returns:
            The internal stacked-observation buffer: all zeros except for the
            trailing slots, which hold the observation returned by the
            wrapped ``reset``.
        """
        obs = self.venv.reset()
        self.stackedobs[...] = 0
        self.stackedobs[..., -obs.shape[-1]:] = obs
        return self.stackedobs
Classes
class VecFrameStack (venv, nstack)
-
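An environment wrapper that applies to an entire batch of environments at once. It keeps the last `nstack` observations of each environment and returns them stacked along the last axis of the observation.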
Expand source code
class VecFrameStack(VecEnvWrapper):
    """Stack the most recent ``nstack`` observations of every sub-environment.

    Frames are concatenated along the last axis of the observation, oldest
    first and newest last, so a wrapped observation whose last dimension is
    ``C`` becomes a stacked observation whose last dimension is
    ``C * nstack``.
    """

    def __init__(self, venv, nstack):
        """Wrap ``venv`` so that it emits stacks of ``nstack`` frames.

        Args:
            venv: Vectorized environment to wrap. This constructor reads its
                ``observation_space`` (``low``, ``high``, ``dtype``) and its
                ``num_envs``.
            nstack: Number of consecutive frames kept per environment.
        """
        self.venv = venv
        self.nstack = nstack
        wos = venv.observation_space  # wrapped ob space
        # Bounds are tiled per frame to match the buffer layout
        # ([frame_0 | frame_1 | ...]); np.repeat would interleave them per
        # channel, which is only equivalent when all channels share a bound.
        low = np.tile(wos.low, self.nstack)
        high = np.tile(wos.high, self.nstack)
        # One zero-initialised stack per sub-environment.
        self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
        observation_space = spaces.Box(low=low, high=high, dtype=wos.dtype)
        VecEnvWrapper.__init__(self, venv, observation_space=observation_space)

    def step_wait(self):
        """Collect the wrapped step result and push the new frame on the stack.

        Returns:
            A tuple ``(stackedobs, rews, news, infos)``. ``rews``, ``news`` and
            ``infos`` are passed through unchanged from the wrapped
            environment; ``stackedobs`` is this wrapper's internal buffer.
        """
        obs, rews, news, infos = self.venv.step_wait()
        # Shift by one whole frame (obs.shape[-1] slots), not by one slot, so
        # multi-channel observations stay aligned inside the stack.
        frame_width = obs.shape[-1]
        self.stackedobs = np.roll(self.stackedobs, shift=-frame_width, axis=-1)
        for i, new in enumerate(news):
            if new:
                # Drop the history of every environment whose flag is set.
                self.stackedobs[i] = 0
        # The newest frame always lives in the trailing slots.
        self.stackedobs[..., -frame_width:] = obs
        return self.stackedobs, rews, news, infos

    def reset(self):
        """Reset the wrapped environments and restart every stack.

        Returns:
            The internal stacked-observation buffer: all zeros except for the
            trailing slots, which hold the observation returned by the
            wrapped ``reset``.
        """
        obs = self.venv.reset()
        self.stackedobs[...] = 0
        self.stackedobs[..., -obs.shape[-1]:] = obs
        return self.stackedobs
Ancestors
- VecEnvWrapper
- VecEnv
- abc.ABC
Inherited members