(self)
| 198 | self.assertEqual(model.arg_scope['order'], 'NHWC') |
| 199 | |
def test_get_params(self):
    """Check which parameters each ModelHelper getter reports per name scope.

    A regular parameter ``a`` and a computed parameter ``b`` are registered
    at the top level. Another ``a`` and a computed ``d`` are registered
    inside name scope ``c``.

    While ``c`` is active, the no-argument getters are expected to report
    only the ``c/...`` entries, and ``GetAllParams('')`` is expected to
    report everything. After leaving the scope, the no-argument getters are
    expected to report every registered parameter. ``GetAllParams`` filtered
    by ``'c'`` or ``'c/'`` is expected to report just the scoped pair.
    """
    def blob_ref(name):
        # The expectations below rely on the name scope active at call time
        # being folded into the blob name, so the call sites must not move.
        return core.ScopedBlobReference(name)

    def names_of(params):
        # Result order is not under test; compare sorted string names.
        return sorted(map(str, params))

    everything = ['a', 'b', 'c/a', 'c/d']
    only_c = ['c/a', 'c/d']

    model = ModelHelper(name="test_model")
    model.AddParameter(blob_ref("a"))
    model.AddParameter(blob_ref("b"), tags=ParameterTags.COMPUTED_PARAM)

    with scope.NameScope("c"):
        model.AddParameter(blob_ref("a"))
        model.AddParameter(blob_ref("d"), tags=ParameterTags.COMPUTED_PARAM)

        # With "c" active, the default (no-argument) getters see only c/*.
        self.assertEqual(names_of(model.GetParams()), ['c/a'])
        self.assertEqual(names_of(model.GetComputedParams()), ['c/d'])
        self.assertEqual(names_of(model.GetAllParams()), only_c)
        # An explicit empty namescope reaches back to the global scope.
        self.assertEqual(names_of(model.GetAllParams('')), everything)

    # Back at the top level, the defaults cover every registered parameter.
    self.assertEqual(names_of(model.GetParams()), ['a', 'c/a'])
    self.assertEqual(names_of(model.GetComputedParams()), ['b', 'c/d'])
    self.assertEqual(names_of(model.GetAllParams()), everything)
    self.assertEqual(names_of(model.GetAllParams('')), everything)

    # Filtering by scope "c" works with or without the trailing separator.
    for prefix in ('c', 'c/'):
        self.assertEqual(names_of(model.GetAllParams(prefix)), only_c)
| 228 | |
| 229 | def test_param_consistence(self): |
| 230 | model = ModelHelper(name='test_mode') |
nothing calls this directly
no test coverage detected