| 405 | np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6) |
| 406 | |
def test_embedding_backward(self, shard_weight_axis=None):
    """Check that Embedding weight gradients agree with and without sharding.

    A reference ``nn.Embedding`` is run forward/backward on random token ids.
    A second layer then receives the reference weights sharded over
    ``devices_2`` along ``shard_weight_axis``. The same tokens are sharded with
    ``axis=None``, and the second layer is run the same way. The two weight
    gradients must match to within 1e-6.
    """
    batch, seq_len, embed_size, vocab_size = 4, 10, 20, 28

    # Reference run on an ordinary (unsharded) layer.
    ref_layer = nn.Embedding(vocab_size, embed_size)
    tokens = Tensor(np.random.randint(0, vocab_size, (batch, seq_len), dtype=np.int32))
    ref_layer(tokens).sum().backward()
    expected_grad = ref_layer.weight.grad.numpy()

    # Sharded run: overwrite the fresh layer's weight with the reference weight
    # so both runs start from identical parameters.
    sharded_layer = nn.Embedding(vocab_size, embed_size)
    sharded_weight = ref_layer.weight.shard(devices_2, axis=shard_weight_axis)
    sharded_layer.weight.replace(sharded_weight).realize()
    sharded_tokens = tokens.shard(devices_2, axis=None)
    sharded_layer(sharded_tokens).sum().backward()
    actual_grad = sharded_layer.weight.grad.numpy()

    np.testing.assert_allclose(expected_grad, actual_grad, atol=1e-6, rtol=1e-6)
| 424 | |
def test_embedding_backward_shard_weight(self):
    """Run ``test_embedding_backward`` with the weight sharded along axis 1."""
    weight_axis = 1
    self.test_embedding_backward(shard_weight_axis=weight_axis)
| 426 | |