MCP · copy
hub / github.com/facebookresearch/dlrm / forward

Method forward

extend_distributed.py:446–465  ·  view source on GitHub ↗
(ctx, *output)

Source from the content-addressed store, hash-verified

444class All2All_Wait(Function):
@staticmethod
def forward(ctx, *output):
    """Wait for the pending all-to-all and split its received tensor.

    Completes the asynchronous request stored on the module-level ``myreq``,
    clears ``myreq.req`` and ``myreq.tensor``, and splits the flat tensor
    ``output[0]`` into chunks, each reshaped to ``[local_batch_num, -1]``.

    Args:
        ctx: autograd context. ``myreq.a2a_info`` is stored on it as
            ``ctx.a2a_info`` (presumably for ``backward`` -- confirm).
        *output: only ``output[0]`` is used. It is assumed to be the flat
            all-to-all receive tensor -- TODO confirm against the caller.

    Returns:
        tuple of tensors, one per split, each viewed as
        ``[a2a_info.local_batch_num, -1]``.
    """
    # ``myreq`` is only read here (its attributes are mutated, but the name
    # is never rebound), so no ``global`` statement is required.
    with record_function("DLRM alltoall_wait_fwd_single"):
        a2a_info = myreq.a2a_info
        ctx.a2a_info = a2a_info
        myreq.req.wait()
        # Clear the stored request and tensor handles once the wait returns.
        myreq.req = None
        myreq.tensor = None
        # Explicit per-split sizes take precedence. Otherwise split into
        # chunks of local_table_num * local_batch_num * emb_dim elements.
        if a2a_info.table_split_lengths:
            split_sizes = a2a_info.table_split_lengths
        else:
            split_sizes = (
                a2a_info.local_table_num
                * a2a_info.local_batch_num
                * a2a_info.emb_dim
            )
        return tuple(
            chunk.view([a2a_info.local_batch_num, -1])
            for chunk in output[0].split(split_sizes)
        )
466
467 @staticmethod
468 def backward(ctx, *grad_outputs):

Callers

nothing calls this directly

Calls 1

wait · Method · 0.80

Tested by

no test coverage detected