Ensures that layers can be constructed and forward-props can run.
(
self,
input_shape: Sequence[int],
input_dtype: Optional[tf.DType] = tf.float32,
**kwargs
)
| 94 | ), |
| 95 | ) |
def testForward(
    self,
    input_shape: Sequence[int],
    input_dtype: Optional[tf.DType] = tf.float32,
    **kwargs
) -> None:
  """Builds a MaxViT model, runs one forward pass and checks its outputs.

  A random input tensor with values drawn uniformly from [-1.0, 1.0) is fed
  through `maxvit.MaxViT(**kwargs)`. The test then checks three things:
  the `pre_logits` endpoint shape (only when `add_gap_layer_norm` is
  truthy), the shape of the endpoint with the largest key, and that the mean
  of that endpoint evaluates to a float32 value.

  Args:
    input_shape: Shape of the random input tensor; element 0 is also the
      expected leading dimension of `pre_logits`.
    input_dtype: Dtype of the random input tensor.
    **kwargs: Forwarded unchanged to `maxvit.MaxViT`. The test additionally
      reads 'training', 'add_gap_layer_norm', 'representation_size' and
      'expected_shape' from it.
  """
  # The input is drawn before the model is built, so the random-op order
  # matches a plain construct-then-call sequence.
  random_input = tf.random.uniform(
      input_shape, minval=-1.0, maxval=1.0, dtype=input_dtype)

  # NOTE(review): every kwarg, including test-only keys such as
  # 'expected_shape', is passed to the constructor; presumably MaxViT
  # tolerates the extra keys -- confirm.
  model = maxvit.MaxViT(**kwargs)
  endpoints = model(random_input, training=kwargs.get('training'))

  if kwargs.get('add_gap_layer_norm', False):
    expected_pre_logits = [input_shape[0], kwargs['representation_size']]
    self.assertAllEqual(
        expected_pre_logits, endpoints['pre_logits'].get_shape().as_list())

  # Discard the optional `pre_logits` entry so only the remaining endpoints
  # take part in the key selection below.
  endpoints.pop('pre_logits', None)
  # NOTE(review): `max` over the endpoint keys; if they are strings this is
  # a lexicographic comparison -- presumably level names like '2'..'5'.
  final_output = endpoints[max(endpoints)]
  self.assertAllEqual(kwargs['expected_shape'],
                      final_output.get_shape().as_list())
  self.assertDTypeEqual(tf.reduce_mean(final_output).numpy(), np.float32)
| 124 | |
| 125 | def testBuildMaxViTWithConfig(self): |
| 126 | backbone_config = backbones.Backbone( |