hub / github.com/tinygrad/tinygrad / test_embedding_backward

Method test_embedding_backward

test/backend/test_multitensor.py:407–423
(self, shard_weight_axis=None)

Source from the content-addressed store, hash-verified

405     np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6)
406
407   def test_embedding_backward(self, shard_weight_axis=None):
408     B, T, embed_size, vocab_size = 4, 10, 20, 28
409
410     layer = nn.Embedding(vocab_size, embed_size)
411     x = Tensor(np.random.randint(0, vocab_size, (B, T), dtype=np.int32))
412     z = layer(x)
413     z.sum().backward()
414     grad = layer.weight.grad.numpy()
415
416     layer_sharded = nn.Embedding(vocab_size, embed_size)
417     layer_sharded.weight.replace(layer.weight.shard(devices_2, axis=shard_weight_axis)).realize()
418     x_sharded = x.shard(devices_2, axis=None)
419     z_shard = layer_sharded(x_sharded)
420     z_shard.sum().backward()
421     grad_shard = layer_sharded.weight.grad.numpy()
422
423     np.testing.assert_allclose(grad, grad_shard, atol=1e-6, rtol=1e-6)
424
425   def test_embedding_backward_shard_weight(self): self.test_embedding_backward(shard_weight_axis=1)
426
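The test above compares the weight gradient of a plain embedding against that of a sharded one. As a rough illustration of what value both gradients should take, here is a NumPy-only sketch that does not use tinygrad at all. It relies on the standard fact that the gradient of `embedding(x).sum()` with respect to the weight table is, for each vocabulary row, the number of times that token occurs in `x`, repeated across the embedding dimension. The helper name `embedding_sum_grad` is hypothetical, and the column split at the end is only an analogy for sharding the weight along axis 1, not a model of how tinygrad distributes work across devices.

```python
import numpy as np

def embedding_sum_grad(indices: np.ndarray, vocab_size: int, embed_size: int) -> np.ndarray:
    """Hypothetical helper: d(embedding(indices).sum()) / d(weight), computed by counting tokens."""
    # Each occurrence of token v contributes 1 to every element of weight row v.
    counts = np.bincount(indices.ravel(), minlength=vocab_size)
    return np.repeat(counts[:, None], embed_size, axis=1).astype(np.float32)

# Same sizes as the test method above.
B, T, embed_size, vocab_size = 4, 10, 20, 28
rng = np.random.default_rng(0)
x = rng.integers(0, vocab_size, size=(B, T), dtype=np.int32)

grad = embedding_sum_grad(x, vocab_size, embed_size)

# Cross-check with the explicit one-hot formulation: grad = onehot^T @ ones.
onehot = np.eye(vocab_size, dtype=np.float32)[x.ravel()]            # (B*T, vocab_size)
upstream = np.ones((B * T, embed_size), dtype=np.float32)           # gradient of .sum()
grad_onehot = onehot.T @ upstream                                   # (vocab_size, embed_size)
np.testing.assert_allclose(grad, grad_onehot, atol=1e-6, rtol=1e-6)

# Analogy for an axis=1 weight split: two column halves, joined back together.
left, right = np.split(grad, 2, axis=1)
reassembled = np.concatenate([left, right], axis=1)
np.testing.assert_allclose(grad, reassembled, atol=1e-6, rtol=1e-6)
```

Because every one of the `B * T` looked-up tokens adds 1 to each of the `embed_size` columns of its row, the entries of `grad` sum to `B * T * embed_size`, which gives a quick sanity check on any backend's result.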

Calls 9

shard · Method · 0.95
Tensor · Class · 0.90
randint · Method · 0.80
sum · Method · 0.80
realize · Method · 0.80
backward · Method · 0.45
numpy · Method · 0.45
replace · Method · 0.45
shard · Method · 0.45

Tested by

no test coverage detected