(Y, G, F, X=None)
| 40 | |
| 41 | |
def broadcast(Y, G, F, X=None):
    """Run the device-appropriate broadcast kernel and return the output ``X``.

    When ``X`` is not supplied, a zero-filled tensor is allocated with shape
    ``G.shape + (Y.shape[-1],)``, using the device and dtype of ``Y``.

    The work is delegated to ``broadcast_cpu`` when ``Y`` lives on a CPU
    device and to ``broadcast_gpu`` for every other device type. Both are
    called as ``(Y, G, F, X)`` and their return values are ignored.

    NOTE(review): the helpers presumably write their result into ``X`` in
    place, since ``X`` is what gets returned -- confirm against their
    definitions. The roles of ``G`` and ``F`` are not visible from here.

    Returns:
        ``X`` -- either the caller-provided tensor or the freshly allocated
        one.
    """
    target = Y.device

    if X is None:
        # Output takes G's shape plus one trailing axis sized like Y's last.
        out_shape = G.shape + (Y.shape[-1],)
        X = torch.zeros(out_shape, device=target, dtype=Y.dtype)

    # Pick the implementation matching where Y is stored.
    kernel = broadcast_cpu if target.type == "cpu" else broadcast_gpu
    kernel(Y, G, F, X)

    return X
| 57 | |
| 58 | |
| 59 | # Divide the cluster into groups of equal size |
no outgoing calls