(self)
| 211 | class NHNetTest(tf.test.TestCase, parameterized.TestCase): |
| 212 | |
| 213 | def setUp(self): |
| 214 | super(NHNetTest, self).setUp() |
| 215 | self._nhnet_config = configs.NHNetConfig() |
| 216 | self._nhnet_config.override(utils.get_test_params().as_dict()) |
| 217 | self._bert2bert_config = configs.BERT2BERTConfig() |
| 218 | self._bert2bert_config.override(utils.get_test_params().as_dict()) |
| 219 | |
| 220 | def _count_params(self, layer, trainable_only=True): |
| 221 | """Returns the count of all model parameters, or just trainable ones.""" |