(self)
| 22 | class EmbeddingTest(absltest.TestCase): |
| 23 | |
| 24 | def test_embeddings(self): |
| 25 | table = pz.nn.EmbeddingTable.from_config( |
| 26 | name="embeddings", |
| 27 | init_base_rng=jax.random.key(2), |
| 28 | vocab_size=17, |
| 29 | embedding_axes={"foo": 5, "bar": 7}, |
| 30 | vocabulary_axis="vocabulary", |
| 31 | ) |
| 32 | |
| 33 | with self.subTest("lookup"): |
| 34 | lookup_layer = pz.nn.EmbeddingLookup(table) |
| 35 | result = lookup_layer(pz.nx.arange("baz", 3)) |
| 36 | pz.chk.check_structure( |
| 37 | result, pz.chk.ArraySpec(named_shape={"foo": 5, "bar": 7, "baz": 3}) |
| 38 | ) |
| 39 | |
| 40 | with self.subTest("decode"): |
| 41 | decode_layer = pz.nn.EmbeddingDecode(table) |
| 42 | result = decode_layer(pz.nx.ones({"foo": 5, "bar": 7, "baz": 3})) |
| 43 | pz.chk.check_structure( |
| 44 | result, pz.chk.ArraySpec(named_shape={"baz": 3, "vocabulary": 17}) |
| 45 | ) |
| 46 | |
| 47 | |
| 48 | if __name__ == "__main__": |
nothing calls this directly
no test coverage detected