MCPcopy
hub / github.com/zai-org/CogView / forward

Method forward

mpu/random.py:281–329  ·  view source on GitHub ↗
(ctx, run_function, *args)

Source from the content-addressed store, hash-verified

279 """
@staticmethod
def forward(ctx, run_function, *args):
    """Run ``run_function(*args)`` without an autograd graph and stash what
    the backward pass needs on ``ctx``.

    In order, this method:
      * caches the model-parallel rank, world size and group in module
        globals on first use;
      * on first use, records the current CUDA device and creates the
        module-level ``transport_stream`` on it (rank 0 also prints the
        partition / correctness-check configuration);
      * if ``PARTITION_ACTIVATIONS`` is set, cuts every argument except the
        last down to this rank's slice of its flattened data (bounds from
        ``get_partition_start`` / ``get_partition_size``); the last argument
        is kept whole;
      * moves all arguments to the CUDA device with ``.to``;
      * snapshots the CPU RNG state, the CUDA RNG state and the CUDA RNG
        tracker states onto ``ctx`` (presumably so backward can replay the
        forward with identical randomness -- backward is not shown here);
      * calls ``run_function`` under ``torch.no_grad()``;
      * saves the arguments via ``ctx.save_for_backward``.

    In partition mode each argument's ``.data`` is replaced IN PLACE by its
    slice, so the caller's tensors are mutated. Each saved argument is also
    followed by a tensor holding its original shape. Partition mode needs at
    least one argument, because ``args[-1]`` is read.

    Args:
        ctx: autograd context; ``run_function`` and the RNG snapshots are
            stored as attributes on it.
        run_function: callable invoked as ``run_function(*device_args)``.
        *args: tensors; every item is moved with ``.to(cuda_device)``.

    Returns:
        Whatever ``run_function`` returns.
    """
    # Module-level state that is lazily initialised here, or only read.
    global mp_rank, mp_size, mp_group
    global cuda_device, transport_stream, PARTITION_ACTIVATIONS

    ctx.run_function = run_function

    # One-time cache of the model-parallel topology.
    if mp_rank is None:
        mp_rank = get_model_parallel_rank()
        mp_size = get_model_parallel_world_size()
        mp_group = get_model_parallel_group()

    # One-time device / side-stream setup.
    if cuda_device is None:
        if dist.get_rank() == 0:
            print(f"Partition Activations {PARTITION_ACTIVATIONS} and Correctness Check {PA_CORRECTNESS_TEST}")
        cuda_device = torch.cuda.current_device()
        # The transport stream is used to overlap the allgather communication
        # for the activations with the computation in the backward pass.
        transport_stream = torch.cuda.Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        def local_slice(tensor):
            # This rank's chunk of the flattened tensor, cloned so it owns its
            # storage. The intermediate views die when this helper returns.
            flat = tensor.detach().contiguous().view(-1)
            return flat.narrow(0, get_partition_start(tensor), get_partition_size(tensor)).clone()

        shards = [local_slice(tensor) for tensor in args[:-1]]
        # The final argument is never partitioned; it is kept as-is.
        shards.append(args[-1])

    # Device-side inputs for the forward call. The original authors added
    # this "just in case" inputs are reused; note that .to() returns the same
    # tensor when it already lives on cuda_device.
    device_inputs = [tensor.to(cuda_device) for tensor in args]

    # Snapshot the RNG streams before running the function.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    with torch.no_grad():
        result = run_function(*device_inputs)

    # Release our references to the device inputs before saving anything.
    del device_inputs

    if PARTITION_ACTIVATIONS:
        saved = []
        for full, shard in zip(args, shards):
            # The full shape must be captured BEFORE the storage is swapped.
            full_shape = torch.tensor(full.size())
            # In place: the caller's tensor now holds only its shard.
            full.data = shard.data
            saved.extend((full, full_shape))
        ctx.save_for_backward(*saved)
    else:
        ctx.save_for_backward(*args)

    return result
330
331 @staticmethod
332 def backward(ctx, *args):

Callers

nothing calls this directly

Calls 7

get_model_parallel_rank · Function · 0.85
get_model_parallel_group · Function · 0.85
get_partition_start · Function · 0.85
get_partition_size · Function · 0.85
get_cuda_rng_tracker · Function · 0.85
get_states · Method · 0.80

Tested by

no test coverage detected