| 386 | np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3) |
| 387 | |
def test_embedding(self):
  """Check that tinygrad's ``Embedding`` matches ``torch.nn.Embedding``.

  A torch embedding is given a copy of the tinygrad layer's weights, then
  both are fed the same random integer indices and their outputs compared
  for three scenarios: a regular ``(B, T)`` batch, a batch whose second
  dimension is zero, and three successive calls through a ``TinyJit``
  wrapped forward function.
  """
  batch, seq_len, embed_size, vocab_size = 4, 10, 20, 28

  # tinygrad layer under test
  layer = Embedding(vocab_size, embed_size)

  # torch reference layer, overwritten in place with the tinygrad weights
  with torch.no_grad():
    torch_layer = torch.nn.Embedding(vocab_size, embed_size).eval()
    torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)

  def check_against_torch(forward, length):
    # draw fresh random indices, run both implementations, compare outputs
    indices = Tensor(np.random.randint(0, vocab_size, (batch, length), dtype=np.int32))
    actual = forward(indices)
    torch_indices = torch.tensor(indices.numpy())
    expected = torch_layer(torch_indices)
    np.testing.assert_allclose(actual.numpy(), expected.detach().numpy(), atol=1e-8, rtol=1e-8)

  # plain forward pass
  check_against_torch(layer, seq_len)

  # empty input length
  check_against_torch(layer, 0)

  # forward pass with jit enabled
  @TinyJit
  def layer_jit(x):
    return layer(x).realize()

  for _ in range(3):
    check_against_torch(layer_jit, seq_len)
| 423 | |
| 424 | def test_embedding_one_kernel(self, ops=612000, kcount=2): |
| 425 | GlobalCounters.reset() |