(x, loop_num=inner_loop)
| 66 | allreduce = AllReduce(mapping=mapping, strategy=strategy) |
| 67 | |
def func(x, loop_num=inner_loop):
    """Invoke ``allreduce`` on ``x`` ``loop_num`` times and return the last result.

    ``allreduce``, ``allreduce_params``, ``fusion`` and ``AllReduceFusionOp``
    are taken from the surrounding scope; ``inner_loop`` supplies the default
    iteration count.

    Args:
        x: Input forwarded unchanged to every ``allreduce`` call.
        loop_num: Number of back-to-back ``allreduce`` calls to issue.

    Returns:
        The result of the final call when ``fusion`` equals
        ``AllReduceFusionOp.NONE``; otherwise element ``0`` of that result.
    """
    # Only the result of the last iteration is kept; earlier ones are
    # overwritten on each pass.
    for _step in range(loop_num):
        last_result = allreduce(x, all_reduce_params=allreduce_params)

    if fusion == AllReduceFusionOp.NONE:
        return last_result
    # For any other fusion setting the caller only receives the first element.
    return last_result[0]
| 72 | |
| 73 | start = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)] |
| 74 | stop = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)] |