(self)
| 36 | assert model.head.weight.shape[0] == pretrained_num_classes |
| 37 | |
def test_finetune(self):
    """Check that overriding ``num_classes`` only changes the classifier head.

    Two ``tiny_vit_5m_224`` models are built from the same
    ``pretrained_type='22kto1k_distill'`` weights: one with the default head
    and one created with ``num_classes=finetune_num_classes``. The test
    asserts that:

    * dim 0 of ``head.weight`` / ``head.bias`` equals the respective class
      count for each model, and
    * every remaining state-dict entry exists in both models and is
      bitwise-equal (``torch.equal``).
    """
    pretrained_num_classes = 1000
    finetune_num_classes = 100
    model1 = timm.create_model('tiny_vit_5m_224', pretrained=True, pretrained_type='22kto1k_distill')
    model2 = timm.create_model('tiny_vit_5m_224', pretrained=True, pretrained_type='22kto1k_distill',
                               num_classes=finetune_num_classes)
    state_dict_1 = model1.state_dict()
    state_dict_2 = model2.state_dict()
    head_keys = ['head.weight', 'head.bias']
    # Pop the head entries so that only the shared (non-head) parameters
    # remain in both dicts for the comparison below.
    for name in head_keys:
        self.assertEqual(state_dict_1.pop(name).shape[0], pretrained_num_classes)
        self.assertEqual(state_dict_2.pop(name).shape[0], finetune_num_classes)
    # Compare key sets first: iterating over state_dict_1 alone would miss
    # entries that exist only in model2, and would raise KeyError (a test
    # error, not a failure) for entries missing from model2.
    self.assertEqual(set(state_dict_1), set(state_dict_2))
    for key, tensor_1 in state_dict_1.items():
        self.assertTrue(torch.equal(tensor_1, state_dict_2[key]),
                        msg=f'{key} differs between the two models')
| 54 | |
| 55 | def test_forward(self): |
| 56 | for variant, _ in self.ckpt_names: |
nothing calls this directly
no test coverage detected