Repository: github.com/PaddlePaddle/ERNIE

Function sequence_parallel_sparse_mask_labels

ernie/sequence_parallel_utils.py:291–306

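All-gathers the sparse labels (entries not equal to ignore_label) from every rank's shard across the model-parallel group, and returns them together with this rank's local indices of those entries.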

(labels, ignore_label=-100)
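To make the data flow concrete, here is a minimal single-process sketch in plain Python that mimics what each rank does: take its contiguous shard of the already-flattened labels, keep the positions whose label is not ignore_label (falling back to position 0 when there are none), and then concatenate every rank's piece. The helper names `simulate_rank` and `simulate_all` are hypothetical, and treating the variable-length all-gather as "concatenate in rank order" is an assumption about `AllGatherVarlenOp`, not something this page confirms.

```python
from typing import List, Tuple

IGNORE = -100


def simulate_rank(labels: List[int], nranks: int, rank: int,
                  ignore_label: int = IGNORE) -> Tuple[List[int], List[int]]:
    """Hypothetical per-rank mimic: split, find non-ignored positions, gather."""
    # paddle.split with an integer section count expects an even split.
    assert len(labels) % nranks == 0, "sequence length must divide evenly"
    shard = len(labels) // nranks
    local = labels[rank * shard:(rank + 1) * shard]
    # Local positions whose label is a real target.
    idx = [i for i, y in enumerate(local) if y != ignore_label]
    if not idx:
        # Same fallback as the source: keep position 0 so the piece is never empty.
        idx = [0]
    return [local[i] for i in idx], idx


def simulate_all(labels: List[int], nranks: int, ignore_label: int = IGNORE):
    """Run every rank, then concatenate the pieces in rank order (assumed gather order)."""
    per_rank = [simulate_rank(labels, nranks, r, ignore_label) for r in range(nranks)]
    gathered = [y for piece, _ in per_rank for y in piece]
    return gathered, [idx for _, idx in per_rank]
```

For example, `simulate_all([5, -100, 7, -100, -100, -100, 9, 3], 2)` gives gathered labels `[5, 7, 9, 3]` with local indices `[0, 2]` on rank 0 and `[2, 3]` on rank 1. Note that when a shard has no real labels, the fallback contributes one entry equal to ignore_label (for instance `simulate_all([1, 2, -100, -100], 2)` gathers `[1, 2, -100]`), which a loss that honors the same ignore value would presumably skip.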


```python
import paddle
from paddle.distributed import fleet

# AllGatherVarlenOp is not part of this excerpt; it is assumed to be defined
# elsewhere in ernie/sequence_parallel_utils.py.


def sequence_parallel_sparse_mask_labels(labels, ignore_label=-100):
    """allgather sparse label and return sparse idx"""
    hcg = fleet.get_hybrid_communicate_group()
    group = hcg.get_model_parallel_group()
    parallelism = group.nranks  # not used below
    labels = labels.flatten()
    # Keep only this rank's contiguous shard of the flattened labels.
    labels_local = paddle.split(labels, group.nranks)[group.rank]

    # Positions in the local shard whose label is not ignored.
    tgt_index = paddle.nonzero(labels_local != ignore_label).squeeze()
    if tgt_index.numel() == 0:
        # No real labels on this shard: fall back to position 0 so the
        # gathered piece is never empty.
        tgt_index = paddle.to_tensor([0])

    tgt_index = tgt_index.reshape([-1]).astype(paddle.int32)
    labels_local_gather = paddle.take_along_axis(labels_local, tgt_index, axis=0)
    # Collect every rank's variable-length piece of sparse labels.
    labels_all_gather = AllGatherVarlenOp.apply(labels_local_gather)
    return labels_all_gather, tgt_index.reshape([-1, 1])
```

Callers (1): forward (method) · 0.70

Calls (1): flatten (method) · 0.80

Tested by: no test coverage detected
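The only listed caller is a `forward` method whose body is not shown on this page. As a hypothetical illustration of why the function returns its index as a column of shape `[-1, 1]`: that is the index layout `paddle.gather_nd` expects for picking whole rows along axis 0, so a caller could use it to keep only the local activations that carry a real label. The sketch below mimics that row selection with NumPy; `select_local_rows` and the `hidden_local` input are assumptions for illustration, not code from ERNIE.

```python
import numpy as np


def select_local_rows(hidden_local, tgt_index):
    """Hypothetical consumer: keep the local rows that have a non-ignored label.

    hidden_local: [local_seq_len, hidden] activations for this rank's shard.
    tgt_index:    [k, 1] integer column, shaped like the function's second output.
    """
    # Same effect as paddle.gather_nd(hidden_local, tgt_index) for a [k, 1]
    # index: each index row names one position along axis 0.
    return hidden_local[tgt_index[:, 0]]
```

With `hidden_local = np.arange(8).reshape(4, 2)` and `tgt_index = np.array([[0], [2]], dtype=np.int32)`, this returns rows `[[0, 1], [4, 5]]`. When the empty-shard fallback fires, the selected row sits at a position whose label is ignore_label, so whatever consumes these rows would presumably rely on the matching ignored label to mask it out.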