(self)
| 32 | |
@unittest.skip("incorrect use of transformer")
def test_small_transformer(self):
  """Run a tiny Transformer repeatedly, then once more after disabling the device's compile cache hook.

  The model's weights are overwritten with uninitialized ``Tensor.empty`` buffers (values
  are irrelevant here). The same three-position inference round is run twice to warm up,
  then ``Device[Device.DEFAULT].compiler.compile_cached`` is set to ``None`` and the round
  is run a final time.
  """
  # NOTE(review): presumably nulling compile_cached makes any kernel that was not already
  # cached fail, so the last round only succeeds if warm-up produced every kernel — confirm.
  args_tiny = {"dim": 16, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 10}
  model = Transformer(**args_tiny)
  for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype).realize())

  def run_positions():
    # One forward pass for each start_pos in 0..2; start_pos is a Variable bounded to
    # [0, 10] and bound to the concrete position, so every call shares the same bounds.
    for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()

  # NOTE: you have to do this twice due to the k-v cache
  run_positions()
  run_positions()
  Device[Device.DEFAULT].compiler.compile_cached = None
  run_positions()
| 43 | |
# Run this module's test cases with the standard unittest runner when executed as a script.
if __name__ == '__main__': unittest.main()
nothing calls this directly
no test coverage detected