(self)
# TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE
@unittest.skipIf(CI and Device.DEFAULT == "DSP", "failing because of make things that can't be images not images")
def test_mnist(self):
  """Smoke-test a small MNIST-style convnet end to end.

  Every parameter is overwritten with ones before the forward pass. Only the
  width of the first output row (10) is asserted.
  """
  # convolutional trunk; construction order matches the classifier-last layout
  feature_extractor = [
    nn.Conv2d(1, 32, 5), Tensor.relu,
    nn.Conv2d(32, 32, 5), Tensor.relu,
    nn.BatchNorm(32), Tensor.max_pool2d,
    nn.Conv2d(32, 64, 3), Tensor.relu,
    nn.Conv2d(64, 64, 3), Tensor.relu,
    nn.BatchNorm(64), Tensor.max_pool2d,
  ]
  classifier = [lambda x: x.flatten(1), nn.Linear(576, 10)]
  layers = feature_extractor + classifier

  # swap the random init for ones, then realize all replacements together
  ones_params = [p.replace(Tensor.ones_like(p).contiguous()) for p in nn.state.get_parameters(layers)]
  Tensor.realize(*ones_params)

  # single-image forward pass
  x = Tensor.rand(1, 1, 28, 28)
  probs = x.sequential(layers).tolist()
  self.assertEqual(len(probs[0]), 10)
| 132 | |
| 133 | # TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE |
| 134 | @unittest.skipIf(CI and Device.DEFAULT == "DSP", "failing because of make things that can't be images not images") |
no test coverage detected