hub / github.com/ddbourgin/numpy-ml / obs_stats

Function obs_stats

numpy_ml/rl_models/rl_utils.py:382–427

Get information on the observation space for `env`: the number of possible observation classes per dimension, the list of all valid observations (when the space is finite), and the number of dimensions in a single observation.

(env, md_obs, cont_obs)

Source

# Module-level imports this function relies on
from itertools import product

import numpy as np


def obs_stats(env, md_obs, cont_obs):
    """
    Get information on the observation space for `env`.

    Parameters
    ----------
    env : ``gym.wrappers`` or ``gym.envs`` instance
        The environment to evaluate.
    md_obs : bool
        Whether the `env`'s observation space is multidimensional.
    cont_obs : bool
        Whether the `env`'s observation space is continuous.

    Returns
    -------
    n_obs_per_dim : list of length (obs_dim,)
        The number of possible observation classes for each dimension of the
        observation space. Dimensions without a finite class count are
        reported as ``np.inf``.
    obs_ids : list or None
        A list of all valid observations within the space. If `cont_obs` is
        True, or if any dimension is not finite, this value will be None.
    obs_dim : int
        The number of dimensions in a single observation.
    """
    if cont_obs:
        # Continuous space: observations cannot be enumerated
        obs_ids = None
        obs_dim = env.observation_space.shape[0]
        n_obs_per_dim = [np.inf for _ in range(obs_dim)]
    else:
        if md_obs:
            # One entry per subspace; subspaces lacking `n` count as infinite
            n_obs_per_dim = [
                space.n if hasattr(space, "n") else np.inf
                for space in env.observation_space.spaces
            ]
            # Enumerate every combination only when all dimensions are finite
            obs_ids = (
                None
                if np.inf in n_obs_per_dim
                else list(product(*[range(i) for i in n_obs_per_dim]))
            )
            obs_dim = len(n_obs_per_dim)
        else:
            # Single discrete dimension with `n` possible observations
            obs_dim = 1
            n_obs_per_dim = [env.observation_space.n]
            obs_ids = list(range(n_obs_per_dim[0]))

    return n_obs_per_dim, obs_ids, obs_dim
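The three branches of `obs_stats` can be illustrated with a small, hedged sketch. It does not use gym: the `SimpleNamespace` objects are hypothetical stand-ins that expose only the attributes the function reads (`n`, `spaces`, `shape`). The `obs_stats` body here is a condensed copy of the listing above, with `math.inf` substituted for `np.inf` so the sketch needs only the standard library. That substitution is an assumption of this example, not the library's code.

```python
from itertools import product
from math import inf
from types import SimpleNamespace


def obs_stats(env, md_obs, cont_obs):
    # Condensed copy of the function above; math.inf stands in for np.inf.
    if cont_obs:
        obs_dim = env.observation_space.shape[0]
        return [inf] * obs_dim, None, obs_dim
    if md_obs:
        n_obs_per_dim = [
            s.n if hasattr(s, "n") else inf for s in env.observation_space.spaces
        ]
        obs_ids = (
            None
            if inf in n_obs_per_dim
            else list(product(*[range(i) for i in n_obs_per_dim]))
        )
        return n_obs_per_dim, obs_ids, len(n_obs_per_dim)
    n = env.observation_space.n
    return [n], list(range(n)), 1


# Hypothetical stand-ins for gym environments (not real gym classes).
discrete_env = SimpleNamespace(observation_space=SimpleNamespace(n=4))
tuple_env = SimpleNamespace(
    observation_space=SimpleNamespace(
        spaces=[SimpleNamespace(n=2), SimpleNamespace(n=3)]
    )
)
mixed_env = SimpleNamespace(
    observation_space=SimpleNamespace(
        # The second subspace has no `n`, so it is treated as infinite.
        spaces=[SimpleNamespace(n=2), SimpleNamespace(shape=(1,))]
    )
)
box_env = SimpleNamespace(observation_space=SimpleNamespace(shape=(3,)))

discrete_result = obs_stats(discrete_env, md_obs=False, cont_obs=False)
tuple_result = obs_stats(tuple_env, md_obs=True, cont_obs=False)
mixed_result = obs_stats(mixed_env, md_obs=True, cont_obs=False)
box_result = obs_stats(box_env, md_obs=False, cont_obs=True)

print(discrete_result)  # ([4], [0, 1, 2, 3], 1)
print(tuple_result[0], tuple_result[2], len(tuple_result[1]))  # [2, 3] 2 6
print(mixed_result)  # ([2, inf], None, 2)
print(box_result)  # ([inf, inf, inf], None, 3)
```

Note that `obs_ids` grows as the product of the per-dimension counts, so enumerating a multidimensional discrete space can become expensive for large spaces.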

Callers 1

env_stats · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected