(src, num_shards, tgt)
| 56 | |
@register_global_func("tests.disco.shard_dim_1", override=True)
def _shard_dim_1(src, num_shards, tgt):
    """Shard a 2-D array along its second axis and copy the stacked shards into ``tgt``.

    With ``src`` of shape ``(rows, cols)`` and ``width = cols // num_shards``,
    the value written to ``tgt`` has shape ``(num_shards, rows, width)``, and
    slice ``k`` of it equals ``src[:, k * width:(k + 1) * width]``.

    Args:
        src: Array-like object exposing ``.shape`` and ``.numpy()``
            (presumably a runtime NDArray; TODO confirm against callers).
            The tuple unpacking below requires it to be 2-D.
        num_shards: Number of equal-width column chunks. It is expected to
            divide ``cols``; otherwise the reshape generally raises ``ValueError``.
        tgt: Destination exposing ``.copyfrom(...)``. It is presumably shaped
            ``(num_shards, rows, width)``; verify with the caller.
    """
    rows, cols = src.shape
    host = src.numpy()
    # (rows, cols) -> (rows, num_shards, width): consecutive column runs become shards.
    chunked = host.reshape(rows, num_shards, cols // num_shards)
    # Bring the shard index to the front: (num_shards, rows, width).
    stacked = chunked.transpose(1, 0, 2)
    tgt.copyfrom(stacked)
| 61 | |
| 62 | |
| 63 | @register_global_func("tests.disco.shard_qkv_0", override=True) |