Module TeachMyAgent.students.openai_baselines.common.vec_env.test_video_recorder

Tests for VecVideoRecorder, the vectorized-environment wrapper that records episode videos.

Expand source code
"""
Tests for VecVideoRecorder, the vectorized-environment wrapper that records episode videos.
"""

import gym
import pytest
import os
import glob
import tempfile

from .dummy_vec_env import DummyVecEnv
from .shmem_vec_env import ShmemVecEnv
from .subproc_vec_env import SubprocVecEnv
from .vec_video_recorder import VecVideoRecorder

@pytest.mark.parametrize('klass', (DummyVecEnv, ShmemVecEnv, SubprocVecEnv))
@pytest.mark.parametrize('num_envs', (1, 4))
@pytest.mark.parametrize('video_length', (10, 100))
@pytest.mark.parametrize('video_interval', (1, 50))
def test_video_recorder(klass, num_envs, video_length, video_interval):
    """
    Wrap an existing VecEnv with VecVideoRecorder,
    make (video_interval + video_length + 1) steps,
    then check that exactly two non-empty .mp4 files were written.

    The environment is closed in a ``finally`` clause so that a failure in
    construction, reset or step does not leave it (and any workers it owns)
    open.
    """

    def make_fn():
        return gym.make('PongNoFrameskip-v4')

    fns = [make_fn] * num_envs

    with tempfile.TemporaryDirectory() as video_path:
        env = klass(fns)
        try:
            # If the wrapper constructor raises, `env` is still the bare
            # VecEnv, which the `finally` below then closes.
            env = VecVideoRecorder(
                env,
                video_path,
                record_video_trigger=lambda x: x % video_interval == 0,
                video_length=video_length,
            )
            env.reset()
            for _ in range(video_interval + video_length + 1):
                env.step([0] * num_envs)
        finally:
            # Close before inspecting the directory; presumably this is what
            # finalizes the video files on disk (the original test relied on
            # the same ordering).
            env.close()

        recorded_video = glob.glob(os.path.join(video_path, "*.mp4"))

        # The recording trigger is expected to fire twice within the step
        # budget, producing two separate video files.
        assert len(recorded_video) == 2
        # Files are not empty
        assert all(os.stat(p).st_size != 0 for p in recorded_video)

Functions

def test_video_recorder(klass, num_envs, video_length, video_interval)

Wrap an existing VecEnv with VecVideoRecorder, make (video_interval + video_length + 1) steps, then check that exactly two non-empty .mp4 video files were written.

Expand source code
@pytest.mark.parametrize('klass', (DummyVecEnv, ShmemVecEnv, SubprocVecEnv))
@pytest.mark.parametrize('num_envs', (1, 4))
@pytest.mark.parametrize('video_length', (10, 100))
@pytest.mark.parametrize('video_interval', (1, 50))
def test_video_recorder(klass, num_envs, video_length, video_interval):
    """
    Wrap an existing VecEnv with VecVideoRecorder,
    make (video_interval + video_length + 1) steps,
    then check that exactly two non-empty .mp4 files were written.

    The environment is closed in a ``finally`` clause so that a failure in
    construction, reset or step does not leave it (and any workers it owns)
    open.
    """

    def make_fn():
        return gym.make('PongNoFrameskip-v4')

    fns = [make_fn] * num_envs

    with tempfile.TemporaryDirectory() as video_path:
        env = klass(fns)
        try:
            # If the wrapper constructor raises, `env` is still the bare
            # VecEnv, which the `finally` below then closes.
            env = VecVideoRecorder(
                env,
                video_path,
                record_video_trigger=lambda x: x % video_interval == 0,
                video_length=video_length,
            )
            env.reset()
            for _ in range(video_interval + video_length + 1):
                env.step([0] * num_envs)
        finally:
            # Close before inspecting the directory; presumably this is what
            # finalizes the video files on disk (the original test relied on
            # the same ordering).
            env.close()

        recorded_video = glob.glob(os.path.join(video_path, "*.mp4"))

        # The recording trigger is expected to fire twice within the step
        # budget, producing two separate video files.
        assert len(recorded_video) == 2
        # Files are not empty
        assert all(os.stat(p).st_size != 0 for p in recorded_video)