| 391 | lr_sched.step() |
| 392 | |
def test_embedding(self):
    """Compare an nn.Embedding forward pass with a copy whose weight is sharded on axis 1."""
    batch, seq_len, embed_size, vocab_size = 4, 10, 20, 28

    # Reference path: a plain layer applied to random int32 ids drawn from [0, vocab_size).
    layer = nn.Embedding(vocab_size, embed_size)
    token_ids = np.random.randint(low=0, high=vocab_size, size=(batch, seq_len), dtype=np.int32)
    x = Tensor(token_ids)
    out_ref = layer(x)

    # Sharded path: a second layer takes over the reference weight, sharded
    # across devices_2 along axis 1, so both layers use the same values.
    layer_sharded = nn.Embedding(vocab_size, embed_size)
    sharded_weight = layer.weight.shard(devices_2, axis=1)
    layer_sharded.weight.replace(sharded_weight).realize()

    # The input is placed on the same devices with axis=None
    # (presumably replicated rather than split -- confirm against Tensor.shard).
    x_sharded = x.shard(devices_2, axis=None)
    out_sharded = layer_sharded(x_sharded)

    # Both paths must agree within tight absolute and relative tolerances.
    np.testing.assert_allclose(out_ref.numpy(), out_sharded.numpy(), atol=1e-6, rtol=1e-6)
| 406 | |
| 407 | def test_embedding_backward(self, shard_weight_axis=None): |
| 408 | B, T, embed_size, vocab_size = 4, 10, 20, 28 |