| 2 | from lib.models.augment_cells import AugmentCell |
| 3 | |
| 4 | class ModelTest(nn.Module): |
| 5 | |
| 6 | def __init__(self, genotypes_dict, model_type, res_stem=False, init_channel=96, stem_multiplier=3, n_nodes=4, num_classes=1000): |
| 7 | """ |
| 8 | args: |
| 9 | |
| 10 | """ |
| 11 | super(ModelTest, self).__init__() |
| 12 | self.c_in = 3 |
| 13 | self.init_channel = init_channel |
| 14 | self.stem_multiplier = stem_multiplier |
| 15 | self.num_classes = num_classes |
| 16 | self.n_nodes = n_nodes |
| 17 | self.model_type = model_type |
| 18 | self.res_stem = res_stem |
| 19 | |
| 20 | if self.model_type == 'cifar': |
| 21 | reduction_p = False |
| 22 | self.layers_reduction = [True, True, False] |
| 23 | self.augment_layers = [7, 7, 6] |
| 24 | self.nas_layers = nn.ModuleList([None, None, None]) |
| 25 | self.feature_extractor = self.cifar_stem(self.init_channel * self.stem_multiplier) |
| 26 | |
| 27 | elif self.model_type == 'imagenet': |
| 28 | if self.res_stem: |
| 29 | reduction_p = False |
| 30 | self.nas_layers = nn.ModuleList([None, None, None, None]) |
| 31 | self.layers_reduction = [False, True, True, True] |
| 32 | self.augment_layers = [3, 4, 3, 4] |
| 33 | self.feature_extractor = self.resnet_stem(self.init_channel * self.stem_multiplier) |
| 34 | else: |
| 35 | reduction_p = True |
| 36 | self.nas_layers = nn.ModuleList([None, None, None]) |
| 37 | self.layers_reduction = [True, True, False] |
| 38 | self.augment_layers = [5, 5, 4] |
| 39 | self.feature_extractor = self.imagenet_stem(self.init_channel * self.stem_multiplier) |
| 40 | else: |
| 41 | raise Exception("Wrong model type!") |
| 42 | |
| 43 | self.nas_layers_num = len(self.nas_layers) |
| 44 | c_p = self.init_channel * self.stem_multiplier |
| 45 | c_pp = self.init_channel * self.stem_multiplier |
| 46 | c_cur = self.init_channel |
| 47 | |
| 48 | for layer_idx, genotype in genotypes_dict.items(): |
| 49 | reduction = self.layers_reduction[layer_idx] |
| 50 | nas_layer = self.generate_nas_layer(c_cur, c_p, c_pp, reduction_p, reduction, genotype, self.augment_layers[layer_idx]) |
| 51 | self.nas_layers[layer_idx] = nas_layer |
| 52 | |
| 53 | if reduction: |
| 54 | c_p = c_cur * 2 * self.n_nodes |
| 55 | else: |
| 56 | c_p = c_cur * self.n_nodes |
| 57 | |
| 58 | if self.res_stem: |
| 59 | c_pp = c_p |
| 60 | reduction_p = False |
| 61 | else: |