Repository: github.com/tinygrad/tinygrad

Method test_mnist

test/test_tiny.py:116–131
(self)


# TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE
@unittest.skipIf(CI and Device.DEFAULT == "DSP", "failing because of make things that can't be images not images")
def test_mnist(self):
  layers = [
    nn.Conv2d(1, 32, 5), Tensor.relu,
    nn.Conv2d(32, 32, 5), Tensor.relu,
    nn.BatchNorm(32), Tensor.max_pool2d,
    nn.Conv2d(32, 64, 3), Tensor.relu,
    nn.Conv2d(64, 64, 3), Tensor.relu,
    nn.BatchNorm(64), Tensor.max_pool2d,
    lambda x: x.flatten(1), nn.Linear(576, 10)]

  # replace random weights with ones
  Tensor.realize(*[p.replace(Tensor.ones_like(p).contiguous()) for p in nn.state.get_parameters(layers)])

  # run model inference
  probs = Tensor.rand(1, 1, 28, 28).sequential(layers).tolist()
  self.assertEqual(len(probs[0]), 10)
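The final `nn.Linear(576, 10)` only works if the flattened feature map has exactly 576 elements. The sketch below traces the (channels, height, width) shape through the same layer order in plain Python. It assumes `nn.Conv2d` uses its defaults of stride 1 and no padding, and that `Tensor.max_pool2d` uses a 2×2 kernel with stride 2. The helpers `conv2d`, `max_pool2d`, `unchanged` and `flatten` are hypothetical shape-only stand-ins, not tinygrad functions, and the batch dimension is omitted. Folding the shape through the list with `functools.reduce` applies the callables left to right, the same order in which `sequential` runs the layers in the test.

```python
import functools

# Hypothetical shape-only stand-ins for the layers in test_mnist.
# Shapes are (C, H, W); the batch dimension is left out.

def conv2d(out_ch, k):
  # Assumes stride 1 and no padding: H_out = H - k + 1
  return lambda s: (out_ch, s[1] - k + 1, s[2] - k + 1)

def max_pool2d(s):
  # Assumes a 2x2 kernel with stride 2: H_out = H // 2
  return (s[0], s[1] // 2, s[2] // 2)

def unchanged(s):
  # relu and BatchNorm do not change the shape
  return s

def flatten(s):
  # Equivalent of x.flatten(1) once the batch dimension is dropped
  return (s[0] * s[1] * s[2],)

layers = [
  conv2d(32, 5), unchanged,
  conv2d(32, 5), unchanged,
  unchanged, max_pool2d,
  conv2d(64, 3), unchanged,
  conv2d(64, 3), unchanged,
  unchanged, max_pool2d,
  flatten]

# Apply each callable to the previous output, left to right
final = functools.reduce(lambda s, f: f(s), layers, (1, 28, 28))
print(final)  # (576,)
```

Each 5×5 convolution shrinks height and width by 4, each 3×3 convolution by 2, and each pooling step halves them: 28 → 24 → 20 → 10 → 8 → 6 → 3. The flattened size is therefore 64 · 3 · 3 = 576, which matches the input width of the last `nn.Linear`.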

Callers (1)

mnist (function) · 0.80

Calls (9)

flatten (method) · 0.80
realize (method) · 0.80
ones_like (method) · 0.80
sequential (method) · 0.80
rand (method) · 0.80
replace (method) · 0.45
contiguous (method) · 0.45
get_parameters (method) · 0.45
tolist (method) · 0.45

Tested by

no test coverage detected