MCPcopy
hub / github.com/microsoft/Cream / test_finetune

Method test_finetune

TinyViT/tests/test_models.py:38–53  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

36 assert model.head.weight.shape[0] == pretrained_num_classes
37
def test_finetune(self):
    """Check that changing ``num_classes`` replaces only the classifier head.

    Builds ``tiny_vit_5m_224`` twice from the same pretrained checkpoint
    (``pretrained_type='22kto1k_distill'``): once with its default head and
    once with ``num_classes=finetune_num_classes``. The ``head.weight`` and
    ``head.bias`` tensors must have the expected leading dimension in each
    model. Every remaining state-dict entry must exist in both models and be
    exactly equal.
    """
    pretrained_num_classes = 1000
    finetune_num_classes = 100
    model1 = timm.create_model('tiny_vit_5m_224', pretrained=True,
                               pretrained_type='22kto1k_distill')
    model2 = timm.create_model('tiny_vit_5m_224', pretrained=True,
                               pretrained_type='22kto1k_distill',
                               num_classes=finetune_num_classes)
    state_dict_1 = model1.state_dict()
    state_dict_2 = model2.state_dict()
    head_keys = ['head.weight', 'head.bias']
    for name in head_keys:
        # Pop the head entries so that only non-head tensors remain below.
        # Dim 0 of each head tensor is checked against the class count.
        self.assertEqual(state_dict_1.pop(name).shape[0], pretrained_num_classes)
        self.assertEqual(state_dict_2.pop(name).shape[0], finetune_num_classes)
    # The non-head keys must match exactly. Iterating only model1's keys
    # would miss entries that exist solely in model2, and would turn a
    # missing entry into a KeyError instead of a readable test failure.
    self.assertEqual(set(state_dict_1), set(state_dict_2))
    for key, value in state_dict_1.items():
        self.assertTrue(
            torch.equal(value, state_dict_2[key]),
            msg='tensor {!r} differs after changing num_classes'.format(key))
54
55 def test_forward(self):
56 for variant, _ in self.ckpt_names:

Callers

nothing calls this directly

Calls 1

state_dict (Method) · 0.45

Tested by

no test coverage detected