All-gather the sparse (non-ignored) labels and return their local sparse indices.
(labels, ignore_label=-100)
| 289 | |
| 290 | |
def sequence_parallel_sparse_mask_labels(labels, ignore_label=-100):
    """All-gather the non-ignored labels across the model-parallel group.

    ``labels`` is flattened and split evenly across the ranks of the
    model-parallel group. This rank keeps only its own shard. The entries of
    that shard that differ from ``ignore_label`` are selected, and the
    selected values are passed to ``AllGatherVarlenOp``.

    Args:
        labels: Label tensor. Its total element count must be divisible by
            the model-parallel world size, because ``paddle.split`` with an
            integer section count requires an even split.
        ignore_label: Label value that is dropped before gathering.
            Defaults to -100.

    Returns:
        A tuple ``(labels_all_gather, tgt_index)``:
            labels_all_gather: result of ``AllGatherVarlenOp.apply`` on this
                rank's selected labels. This is presumably the concatenation
                over all ranks; TODO confirm against AllGatherVarlenOp.
            tgt_index: int32 tensor of shape ``[k, 1]`` holding the positions
                of the selected labels within this rank's shard.

    Note:
        If the local shard contains no label other than ``ignore_label``,
        position 0 is selected instead, so this rank still contributes one
        element. In that case every element of the shard equals
        ``ignore_label``, so the contributed value is ``ignore_label``.
    """
    hcg = fleet.get_hybrid_communicate_group()
    group = hcg.get_model_parallel_group()

    # Keep only this rank's contiguous shard of the flattened labels.
    labels = labels.flatten()
    labels_local = paddle.split(labels, group.nranks)[group.rank]

    # labels_local is 1-D, so nonzero() yields a [k, 1] index tensor.
    # Flatten it to [k].
    tgt_index = paddle.nonzero(labels_local != ignore_label).reshape([-1])
    if tgt_index.numel() == 0:
        # Ensure this rank contributes at least one element. Here
        # labels_local[0] == ignore_label, so the placeholder stays ignorable.
        tgt_index = paddle.to_tensor([0])

    tgt_index = tgt_index.astype(paddle.int32)
    labels_local_gather = paddle.take_along_axis(labels_local, tgt_index, axis=0)
    labels_all_gather = AllGatherVarlenOp.apply(labels_local_gather)
    return labels_all_gather, tgt_index.reshape([-1, 1])
| 307 | |
| 308 | |
| 309 | ################################################### |