MCPcopy Index your code
hub / github.com/pytorch/pytorch / test_get_params

Method test_get_params

caffe2/python/brew_test.py:200–227  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

198 self.assertEqual(model.arg_scope['order'], 'NHWC')
199
def test_get_params(self):
    """Check that parameter getters respect name scopes and parameter tags.

    One plain parameter ("a") and one COMPUTED_PARAM-tagged parameter ("b")
    are registered at the root scope. The same is done for "a" and "d"
    under name scope "c". The test then checks what GetParams,
    GetComputedParams and GetAllParams return in three situations:
    inside scope "c", back at the root scope, and when an explicit scope
    string is passed.
    """
    computed_tag = ParameterTags.COMPUTED_PARAM

    def blob(name):
        # Blob reference built relative to whatever name scope is active.
        return core.ScopedBlobReference(name)

    def expect(params, names):
        # Compare as sorted string names so registration order is irrelevant.
        self.assertEqual(sorted(map(str, params)), names)

    model = ModelHelper(name="test_model")
    model.AddParameter(blob("a"))
    model.AddParameter(blob("b"), tags=computed_tag)

    with scope.NameScope("c"):
        model.AddParameter(blob("a"))
        model.AddParameter(blob("d"), tags=computed_tag)
        # Without an argument, the getters are expected to see only "c/...".
        expect(model.GetParams(), ['c/a'])
        expect(model.GetComputedParams(), ['c/d'])
        expect(model.GetAllParams(), ['c/a', 'c/d'])
        # An explicit '' asks for the global scope even while inside "c".
        expect(model.GetAllParams(''), ['a', 'b', 'c/a', 'c/d'])

    # Back at the root scope, every registered parameter is visible.
    expect(model.GetParams(), ['a', 'c/a'])
    expect(model.GetComputedParams(), ['b', 'c/d'])
    expect(model.GetAllParams(), ['a', 'b', 'c/a', 'c/d'])
    expect(model.GetAllParams(''), ['a', 'b', 'c/a', 'c/d'])

    # Scope "c" may be named with or without the trailing separator.
    expect(model.GetAllParams('c'), ['c/a', 'c/d'])
    expect(model.GetAllParams('c/'), ['c/a', 'c/d'])
228
229 def test_param_consistence(self):
230 model = ModelHelper(name='test_mode')

Callers

nothing calls this directly

Calls 6

AddParameter (method) · 0.95
GetParams (method) · 0.95
GetComputedParams (method) · 0.95
GetAllParams (method) · 0.95
ModelHelper (class) · 0.90
assertEqual (method) · 0.45

Tested by

no test coverage detected