(self)
| 69 | self.assertEqual(x.shard(devices_2, 0).realize().shrink(((0, 2),)).tolist(), [0, 1]) |
| 70 | |
def test_shard_like(self):
  """shard_like should give a tensor the same device placement (and shard axis) as the reference tensor."""
  # sharded reference tensors: both axis=0 and axis=None must carry over to the result
  for axis in (0, None):
    ref = Tensor.ones(256).shard(devices_2, axis=axis)
    out = Tensor.zeros(256).shard_like(ref)
    self.assertEqual(out.device, ref.device)
    self.assertEqual(out.uop.axis, axis)
  # unsharded (single-device) reference: only the device is compared
  single = Tensor.ones(256)
  self.assertEqual(Tensor.zeros(256).shard_like(single).device, single.device)
  # an already-sharded tensor may only be shard_like'd when doing so is a no-op;
  # a mismatched layout (axis=None vs the reference's axis=0) must raise
  ref = Tensor.ones(256).shard(devices_2, 0)
  same = Tensor.ones(256).shard(devices_2, 0).shard_like(ref)
  self.assertEqual(same.device, ref.device)
  self.assertEqual(same.uop.axis, 0)
  with self.assertRaises(RuntimeError):
    Tensor.ones(256).shard(devices_2, None).shard_like(ref)
| 92 | |
| 93 | def _test_shard_op(self, op, out, n=4): |
| 94 | t = Tensor.ones(n).contiguous().realize().shard(devices_2, 0) |
nothing calls this directly
no test coverage detected