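Perform an allreduce operation on an array. Parameters ---------- src : DRef The array to be reduced. dst : DRef The destination array that receives the reduced result. op : str = "sum" The reduce operation to be performed. Available options are: - "sum" - "prod" - "min" - "max" - "avg" in_group : bool Whether to reduce within the group (True, the default) or globally (False).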
(
self,
src: DRef,
dst: DRef,
op: str = "sum", # pylint: disable=invalid-name
in_group: bool = True,
)
| 482 | func(from_array, in_group, to_array) |
| 483 | |
def allreduce(
    self,
    src: DRef,
    dst: DRef,
    op: str = "sum",  # pylint: disable=invalid-name
    in_group: bool = True,
) -> DRef:
    """Perform an allreduce operation on an array.

    Parameters
    ----------
    src : DRef
        The array to be reduced.

    dst : DRef
        The destination array that receives the reduced result.

    op : str = "sum"
        The reduce operation to be performed. Available options are:
        - "sum"
        - "prod"
        - "min"
        - "max"
        - "avg"

    in_group : bool
        Whether to reduce within the group (True, the default) or
        globally (False).

    Returns
    -------
    dst : DRef
        The destination array ``dst``. It is returned so that the annotated
        ``-> DRef`` return type actually holds.

    Raises
    ------
    ValueError
        If ``op`` is not a key of ``REDUCE_OPS``.
    """
    if op not in REDUCE_OPS:
        # list(...) renders as ['sum', ...] instead of dict_keys([...]).
        raise ValueError(
            f"Unsupported reduce op: {op}. Available ops are: {list(REDUCE_OPS)}"
        )
    # The packed function receives the op as a one-element Shape holding the
    # value REDUCE_OPS maps the name to. A separate local keeps `op` a `str`,
    # matching its annotation.
    op_code = Shape([REDUCE_OPS[op]])
    func = self._get_cached_method("runtime.disco.allreduce")
    func(src, op_code, in_group, dst)
    return dst
| 514 | |
| 515 | def allgather( |
| 516 | self, |