MCPcopy Index your code
hub / github.com/ddbourgin/numpy-ml / EnvModel

Class EnvModel

numpy_ml/rl_models/rl_utils.py:28–121  ·  view source on GitHub ↗

A simple tabular environment model that maintains the counts of each reward-outcome pair given the state and action that preceded them. The model can be queried with `>>> M = EnvModel()`, `>>> M[(state, action, reward, next_state)] += 1`, `>>> M[(state, action, reward, next_state)]`, …

Source from the content-addressed store, hash-verified

26
27
28class EnvModel(object):
29 """
30 A simple tabular environment model that maintains the counts of each
31 reward-outcome pair given the state and action that preceded them. The
32 model can be queried with
33
34 >>> M = EnvModel()
35 >>> M[(state, action, reward, next_state)] += 1
36 >>> M[(state, action, reward, next_state)]
37 1
38 >>> M.state_action_pairs()
39 [(state, action)]
40 >>> M.outcome_probs(state, action)
41 [(next_state, 1)]
42 """
43
44 def __init__(self):
45 super(EnvModel, self).__init__()
46 self._model = defaultdict(lambda: defaultdict(lambda: 0))
47
48 def __setitem__(self, key, value):
49 """Set self[key] to value"""
50 s, a, r, s_ = key
51 self._model[(s, a)][(r, s_)] = value
52
53 def __getitem__(self, key):
54 """Return the value associated with key"""
55 s, a, r, s_ = key
56 return self._model[(s, a)][(r, s_)]
57
58 def __contains__(self, key):
59 """True if EnvModel contains `key`, else False"""
60 s, a, r, s_ = key
61 p1 = (s, a) in self.state_action_pairs()
62 p2 = (r, s_) in self.reward_outcome_pairs()
63 return p1 and p2
64
65 def state_action_pairs(self):
66 """Return all (state, action) pairs in the environment model"""
67 return list(self._model.keys())
68
69 def reward_outcome_pairs(self, s, a):
70 """
71 Return all (reward, next_state) pairs associated with taking action `a`
72 in state `s`.
73 """
74 return list(self._model[(s, a)].keys())
75
76 def outcome_probs(self, s, a):
77 """
78 Return the probability under the environment model of each outcome
79 state after taking action `a` in state `s`.
80
81 Parameters
82 ----------
83 s : int as returned by ``self._obs2num``
84 The id for the state/observation.
85 a : int as returned by ``self._action2num``

Callers 1

_init_paramsMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected