(ctx, run_function, *args)
| 279 | """ |
@staticmethod
def forward(ctx, run_function, *args):
    """Run ``run_function`` on ``args`` with gradient tracking disabled.

    The callable and the RNG states in effect just before the call are
    stashed on ``ctx``, and the arguments are saved via
    ``ctx.save_for_backward``.

    Module-level caches (``mp_rank``/``mp_size``/``mp_group`` and
    ``cuda_device``/``transport_stream``) are filled in on the first call
    that finds them unset.

    When ``PARTITION_ACTIVATIONS`` is truthy, every argument except the
    last has its ``.data`` replaced in place by a flattened slice (chosen
    by ``get_partition_start`` / ``get_partition_size``), and each saved
    argument is followed by a tensor holding its original size.

    Args:
        ctx: Autograd context object; receives ``run_function``,
            ``fwd_cpu_rng_state``, ``fwd_cuda_rng_state`` and
            ``fwd_cuda_rng_state_tracker``.
        run_function: Callable invoked as ``run_function(*args)`` on
            copies of ``args`` moved to the cached CUDA device.
        *args: Positional inputs; each must support ``.to(device)``.

    Returns:
        Whatever ``run_function`` returns.
    """
    global mp_rank, mp_size, mp_group
    global cuda_device, transport_stream, PARTITION_ACTIVATIONS

    ctx.run_function = run_function

    # Model-parallel info is looked up once and then reused from the globals.
    if mp_rank is None:
        mp_rank = get_model_parallel_rank()
        mp_size = get_model_parallel_world_size()
        mp_group = get_model_parallel_group()

    # One-time device setup; rank 0 also reports the active configuration.
    if cuda_device is None:
        if dist.get_rank() == 0:
            print(f"Partition Activations {PARTITION_ACTIVATIONS} and Correctness Check {PA_CORRECTNESS_TEST}")

        cuda_device = torch.cuda.current_device()
        # The transport stream is used to overlap the allgather communication
        # for the activations with the computation in the backward pass.
        transport_stream = torch.cuda.Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        # Clone a flattened slice of every argument except the last; the
        # last argument is carried through untouched.
        partitions = []
        for tensor in args[:-1]:
            partitions.append(
                tensor.detach().contiguous().view(-1)
                .narrow(0, get_partition_start(tensor), get_partition_size(tensor))
                .clone()
            )
        partitions.append(args[-1])

    # Move the (full) inputs to the device, in case inputs are being reused.
    device_args = [arg.to(cuda_device) for arg in args]

    # Snapshot the RNG states before running the function.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    with torch.no_grad():
        outputs = run_function(*device_args)

    del device_args

    if not PARTITION_ACTIVATIONS:
        ctx.save_for_backward(*args)
        return outputs

    # Record each argument's original size, then point its storage at the
    # partitioned slice; argument and size are saved as adjacent entries.
    saved = []
    for original, shard in zip(args, partitions):
        full_shape = torch.tensor(original.size())
        original.data = shard.data
        saved.extend((original, full_shape))
    ctx.save_for_backward(*saved)

    return outputs
| 330 | |
| 331 | @staticmethod |
| 332 | def backward(ctx, *args): |
nothing calls this directly
no test coverage detected