(grad_weight:UOp, grad_emb:UOp, idx:UOp)
| 331 | |
| 332 | # this is the real atomic kernel |
def _embedding_bwd_kernel(grad_weight:UOp, grad_emb:UOp, idx:UOp) -> UOp:
  """Build the embedding backward kernel: an atomic scatter-add of grad_emb into grad_weight.

  For every flattened index position i and every embedding column j the kernel emits
    grad_weight[row(i), j] += float(grad_emb_flat[i, j])
  through a device-specific atomic fetch-add passed as an Ops.CUSTOM format string.

  Row selection:
    - not vocab sharded: row is idx clipped to [0, grad_weight.shape[0]-1]; the gradient is added unmasked.
    - vocab sharded: row is (idx - shard offset) clipped to the same bounds, and the added value is
      replaced by 0.0 when idx falls outside [offset, offset + local_vocab_size).

  `device`, `is_vocab_sharded`, `ndev` and `local_vocab_size` are read from the enclosing scope
  (defined outside this block).

  Raises AssertionError when the embedding width is not divisible by the column block size, and
  NotImplementedError when `device` is none of "CPU", "NULL", "AMD".
  """
  embed_size = grad_weight.shape[-1]
  tokens = idx.flatten()
  grads = grad_emb.reshape((idx.numel(), embed_size))

  # columns are split into n_j_blocks blocks of BLOCK_J (at most 256) columns each
  BLOCK_J = min(256, embed_size)
  assert embed_size % BLOCK_J == 0, f"embed_size {embed_size} must be divisible by {BLOCK_J}"
  n_j_blocks = embed_size // BLOCK_J

  on_host = device in ("CPU", "NULL")
  i = UOp.range(grads.shape[0], 0)  # batch_size * sequence_length -> GLOBAL
  # BLOCK_J threads per workgroup (a plain loop on CPU/NULL)
  j_inner = UOp.range(BLOCK_J, 2, AxisType.LOOP if on_host else AxisType.LOCAL)
  j_outer = UOp.range(n_j_blocks, 1)
  j = j_outer * BLOCK_J + j_inner

  last_row = grad_weight.shape[0]-1
  grad_f32 = grads[i, j].cast(dtypes.float)
  if is_vocab_sharded:
    # each device owns [offset, offset+local_vocab_size) of the global vocabulary
    dnum = UOp.variable("_device_num", 0, ndev-1)
    shard_start = dnum * local_vocab_size
    token = tokens[i].cast(dtypes.weakint)
    row = (token - shard_start).clip(0, last_row)
    owned = (token >= shard_start) & (token < (shard_start + local_vocab_size))
    # tokens owned by another shard still hit a (clipped) row, but contribute zero
    grad_val = owned.where(grad_f32, 0.0)
  else:
    row = tokens[i].clip(0, last_row).cast(dtypes.weakint)
    grad_val = grad_f32

  # atomic scatter-add: grad_weight[row, j] += grad_val
  if on_host:
    atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"
  elif device == "AMD":
    atomic_arg = "__hip_atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);"
  else:
    raise NotImplementedError(f"no atomics for device {device}")

  target = grad_weight.index(row, j, ptr=True)
  atomic = UOp(Ops.CUSTOM, dtypes.void, (target, grad_val), arg=atomic_arg)
  return atomic.end(i, j_outer, j_inner).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=()))
| 363 | |
| 364 | grad_weight_uop = grad_weight_uop.custom_kernel(grad_emb, idx, fxn=_embedding_bwd_kernel)[0] |
| 365 |
nothing calls this directly
no test coverage detected
searching dependent graphs…