MCPcopy
hub / github.com/microsoft/Cream / ModelTest

Class ModelTest

CDARTS/lib/models/model_test.py:4–166  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

2from lib.models.augment_cells import AugmentCell
3
4class ModelTest(nn.Module):
5
6 def __init__(self, genotypes_dict, model_type, res_stem=False, init_channel=96, stem_multiplier=3, n_nodes=4, num_classes=1000):
7 """
8 args:
9
10 """
11 super(ModelTest, self).__init__()
12 self.c_in = 3
13 self.init_channel = init_channel
14 self.stem_multiplier = stem_multiplier
15 self.num_classes = num_classes
16 self.n_nodes = n_nodes
17 self.model_type = model_type
18 self.res_stem = res_stem
19
20 if self.model_type == 'cifar':
21 reduction_p = False
22 self.layers_reduction = [True, True, False]
23 self.augment_layers = [7, 7, 6]
24 self.nas_layers = nn.ModuleList([None, None, None])
25 self.feature_extractor = self.cifar_stem(self.init_channel * self.stem_multiplier)
26
27 elif self.model_type == 'imagenet':
28 if self.res_stem:
29 reduction_p = False
30 self.nas_layers = nn.ModuleList([None, None, None, None])
31 self.layers_reduction = [False, True, True, True]
32 self.augment_layers = [3, 4, 3, 4]
33 self.feature_extractor = self.resnet_stem(self.init_channel * self.stem_multiplier)
34 else:
35 reduction_p = True
36 self.nas_layers = nn.ModuleList([None, None, None])
37 self.layers_reduction = [True, True, False]
38 self.augment_layers = [5, 5, 4]
39 self.feature_extractor = self.imagenet_stem(self.init_channel * self.stem_multiplier)
40 else:
41 raise Exception("Wrong model type!")
42
43 self.nas_layers_num = len(self.nas_layers)
44 c_p = self.init_channel * self.stem_multiplier
45 c_pp = self.init_channel * self.stem_multiplier
46 c_cur = self.init_channel
47
48 for layer_idx, genotype in genotypes_dict.items():
49 reduction = self.layers_reduction[layer_idx]
50 nas_layer = self.generate_nas_layer(c_cur, c_p, c_pp, reduction_p, reduction, genotype, self.augment_layers[layer_idx])
51 self.nas_layers[layer_idx] = nas_layer
52
53 if reduction:
54 c_p = c_cur * 2 * self.n_nodes
55 else:
56 c_p = c_cur * self.n_nodes
57
58 if self.res_stem:
59 c_pp = c_p
60 reduction_p = False
61 else:

Callers 2

mainFunction · 0.90
mainFunction · 0.90

Calls

no outgoing calls

Tested by 1

mainFunction · 0.72