MCPcopy
hub / github.com/tinygrad/tinygrad / test_shard_like

Method test_shard_like

test/backend/test_multitensor.py:71–91  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

69 self.assertEqual(x.shard(devices_2, 0).realize().shrink(((0, 2),)).tolist(), [0, 1])
70
def test_shard_like(self):
  """Check that Tensor.shard_like reproduces the device placement (and shard axis) of a reference tensor."""
  # multi-device references: sharded along axis 0 first, then replicated (axis=None)
  for ref_axis in (0, None):
    ref = Tensor.ones(256).shard(devices_2, axis=ref_axis)
    out = Tensor.zeros(256).shard_like(ref)
    self.assertEqual(out.device, ref.device)
    self.assertEqual(out.uop.axis, ref_axis)
  # single-device reference
  single_ref = Tensor.ones(256)
  single_out = Tensor.zeros(256).shard_like(single_ref)
  self.assertEqual(single_out.device, single_ref.device)
  # an already-multi tensor may only be shard_like'd when that is a no-op (same devices and axis)
  multi_ref = Tensor.ones(256).shard(devices_2, 0)
  same_layout = Tensor.ones(256).shard(devices_2, 0).shard_like(multi_ref)
  self.assertEqual(same_layout.device, multi_ref.device)
  self.assertEqual(same_layout.uop.axis, 0)
  # a differing layout (replicated -> axis 0) is rejected; the shard call stays inside the context on purpose
  with self.assertRaises(RuntimeError):
    Tensor.ones(256).shard(devices_2, None).shard_like(multi_ref)
92
93 def _test_shard_op(self, op, out, n=4):
94 t = Tensor.ones(n).contiguous().realize().shard(devices_2, 0)

Callers

nothing calls this directly

Calls 4

shard_like — Method · 0.80
zeros — Method · 0.80
shard — Method · 0.45
ones — Method · 0.45

Tested by

no test coverage detected