(t:Tensor)
| 26 | # shard_w is "model parallel" |
| 27 | |
def _test_allreduce(t:Tensor):
    """Return a locally computed reference and the allreduce result for ``t``.

    The reference adds the slices ``t[0:64]``, ``t[64:128]``, ``t[128:192]``
    and ``t[192:256]`` and repeats that sum with ``repeat([4,1])``. The second
    value is built by sharding ``t`` over ``devices_4`` on axis 0 and wrapping
    ``UOp.allreduce(..., Ops.ADD, ...)`` of the sharded tensor in a ``Tensor``.

    NOTE(review): presumably ``t`` is 2-D with 256 rows so that each of the
    four shards matches one 64-row slice -- confirm against the callers.

    Returns:
        A ``(reference, allreduced)`` tuple; ``realize()`` has been called on both.
    """
    # Accumulate left to right, the same association order as a single chained sum.
    slice_sum = t[0:64] + t[64:128]
    slice_sum = slice_sum + t[128:192]
    slice_sum = slice_sum + t[192:256]
    reference = slice_sum.repeat([4,1]).realize()

    sharded = t.shard(devices_4, 0).realize()
    allreduced = Tensor(UOp.allreduce(sharded.uop, Ops.ADD, sharded.device))
    allreduced.realize()
    return reference, allreduced
| 34 | |
| 35 | @unittest.skipIf(not_support_multi_device(), "no multi") |
| 36 | class TestMultiTensor(unittest.TestCase): |
no test coverage detected
searching dependent graphs…