MCPcopy
hub / github.com/tinygrad/tinygrad / _embedding_bwd_kernel

Function _embedding_bwd_kernel

tinygrad/nn/__init__.py:333–362  ·  view source on GitHub ↗
(grad_weight:UOp, grad_emb:UOp, idx:UOp)

Source from the content-addressed store, hash-verified

331
332 # this is the real atomic kernel
def _embedding_bwd_kernel(grad_weight:UOp, grad_emb:UOp, idx:UOp) -> UOp:
  """Build the "embedding_bwd" kernel: an atomic scatter-add of grad_emb rows into grad_weight.

  For every flattened position i and every column j the kernel emits
  grad_weight[token_id, j] += grad_emb_flat[i, j] through a device-specific atomic fetch-add.

  grad_weight: accumulation target, indexed as (row, column); its last dim is the embedding size.
  grad_emb: incoming gradient, reshaped here to (idx.numel(), embedding size).
  idx: token ids, flattened here; one id per row of the reshaped grad_emb.

  Returns the sink UOp of the kernel (KernelInfo name "embedding_bwd", no opts applied).
  Raises NotImplementedError when `device` is not one of "CPU", "NULL" or "AMD".

  NOTE(review): `device`, `is_vocab_sharded`, `ndev` and `local_vocab_size` are not parameters; they are
  read from the enclosing scope, which is not visible from this block -- confirm their types there.
  """
  embed_size = grad_weight.shape[-1]
  flat_idx = idx.flatten()
  flat_grad = grad_emb.reshape((idx.numel(), embed_size))

  # columns are processed in blocks of at most 256, and the block width has to tile the row exactly
  BLOCK_J = min(256, embed_size)
  assert embed_size % BLOCK_J == 0, f"embed_size {embed_size} must be divisible by {BLOCK_J}"
  num_blocks = embed_size // BLOCK_J

  cpu_or_null = device in ("CPU", "NULL")
  row = UOp.range(flat_grad.shape[0], 0)  # batch_size * sequence_length -> GLOBAL
  # BLOCK_J threads per workgroup; CPU and NULL use a plain loop axis instead of a local one
  col_in_block = UOp.range(BLOCK_J, 2, AxisType.LOOP if cpu_or_null else AxisType.LOCAL)
  block_id = UOp.range(num_blocks, 1)
  col = block_id * BLOCK_J + col_in_block

  last_row = grad_weight.shape[0]-1
  grad_as_float = flat_grad[row, col].cast(dtypes.float)

  if is_vocab_sharded:
    # each device owns [shard_start, shard_end) of the global vocabulary
    shard_start = UOp.variable("_device_num", 0, ndev-1) * local_vocab_size
    shard_end = shard_start + local_vocab_size
    token = flat_idx[row].cast(dtypes.weakint)
    # the destination row is clipped so the address stays valid even for ids owned by another device...
    dst_row = (token - shard_start).clip(0, last_row)
    # ...and those ids contribute 0.0 instead of their gradient
    owned = (token >= shard_start) & (token < shard_end)
    contrib = owned.where(grad_as_float, 0.0)
  else:
    # ids are clipped into [0, last_row] (not masked) before being used as the destination row
    dst_row = flat_idx[row].clip(0, last_row).cast(dtypes.weakint)
    contrib = grad_as_float

  # atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j]
  if cpu_or_null:
    atomic_fmt = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"
  elif device == "AMD":
    atomic_fmt = "__hip_atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);"
  else:
    raise NotImplementedError(f"no atomics for device {device}")

  dst_ptr = grad_weight.index(dst_row, col, ptr=True)
  scatter_add = UOp(Ops.CUSTOM, dtypes.void, (dst_ptr, contrib), arg=atomic_fmt)
  kernel_info = KernelInfo(name="embedding_bwd", opts_to_apply=())
  return scatter_add.end(row, block_id, col_in_block).sink(arg=kernel_info)
363
364 grad_weight_uop = grad_weight_uop.custom_kernel(grad_emb, idx, fxn=_embedding_bwd_kernel)[0]
365

Callers

nothing calls this directly

Calls 13

endMethod · 0.95
UOpClass · 0.90
KernelInfoClass · 0.90
flattenMethod · 0.80
reshapeMethod · 0.80
numelMethod · 0.80
variableMethod · 0.80
clipMethod · 0.80
rangeMethod · 0.45
castMethod · 0.45
whereMethod · 0.45
indexMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…