(self)
| 426 | |
| 427 | class TestSamplingTransform(unittest.TestCase): |
def test_tokendrop(self):
    """Check that token drop skips eval examples and alters training ones.

    Builds the ``tokendrop`` transform with seed 3434 and temperature 0.1,
    warms it up, then applies it to a fresh deep copy of the same example
    twice: once with ``is_train=False`` (result must equal the input) and
    once with ``is_train=True`` (result must differ from the input).
    """
    transform_cls = get_transforms_cls(["tokendrop"])["tokendrop"]
    options = Namespace(seed=3434, tokendrop_temperature=0.1)
    transform = transform_cls(options)
    transform.warm_up()
    example = dict(
        src=["Hello", ",", "world", "."],
        tgt=["Bonjour", "le", "monde", "."],
    )

    # Non-training examples must pass through untouched. A deep copy is
    # passed so that `example` stays pristine for the comparisons below.
    untouched = transform.apply(copy.deepcopy(example), is_train=False)
    self.assertEqual(untouched, example)

    # Training examples are expected to come out modified.
    dropped = transform.apply(copy.deepcopy(example), is_train=True)
    self.assertNotEqual(dropped, example)
| 444 | def test_tokenmask(self): |
| 445 | tokenmask_cls = get_transforms_cls(["tokenmask"])["tokenmask"] |
nothing calls this directly
no test coverage detected