(src, num_shards, tgt)
| 50 | |
@register_global_func("tests.disco.shard_dim_0", override=True)
def _shard_dim_0(src, num_shards, tgt):
    """Split a 2-D ``src`` along its first axis and copy the result into ``tgt``.

    The host copy of ``src`` is viewed as
    ``(num_shards, rows // num_shards, cols)`` and handed to ``tgt.copyfrom``.

    Parameters
    ----------
    src
        Source array; its ``shape`` must unpack into exactly two values and
        it must expose ``numpy()``.
    num_shards
        Number of pieces to cut the first axis into.
    tgt
        Destination object receiving the sharded data via ``copyfrom``.
        NOTE(review): presumably pre-allocated with the sharded shape by the
        caller -- confirm.
    """
    rows, cols = src.shape
    host_array = src.numpy()
    # Floor division: the reshape only succeeds when ``rows`` is an exact
    # multiple of ``num_shards`` (element counts must match).
    sharded = host_array.reshape(num_shards, rows // num_shards, cols)
    tgt.copyfrom(sharded)
| 55 | |
| 56 | |
| 57 | @register_global_func("tests.disco.shard_dim_1", override=True) |