(self, image)
| 112 | weight_decay = 4e-5 |
| 113 | |
def get_logits(self, image):
    """Build the ShuffleNet (V1 or V2) classification graph and return its logits.

    The architecture is selected by the module-level ``args``:
    ``args.v2`` chooses ShuffleNet V2 instead of V1, ``args.group`` is the
    group count passed to every stage, and ``args.ratio`` is the width
    multiplier.

    Args:
        image: input image tensor. All conv/pool/BN layers are built with
            ``data_format='channels_first'``, so this presumably expects NCHW
            input -- TODO confirm against the caller's preprocessing.

    Returns:
        Output of a 1000-way ``FullyConnected`` layer (unnormalized logits).

    Raises:
        ValueError: if ``args.group`` (V1) or ``args.ratio`` (V2) has no entry
            in the corresponding channel table from the paper.
    """
    with argscope([Conv2D, MaxPooling, AvgPooling, GlobalAvgPooling, BatchNorm], data_format='channels_first'), \
            argscope(Conv2D, use_bias=False):

        group = args.group
        if not args.v2:
            # Stage 2/3/4 output channels at width 1x, keyed by group count
            # (ShuffleNet V1 paper, Table 1).
            v1_channels = {
                1: [144, 288, 576],
                2: [200, 400, 800],
                3: [240, 480, 960],
                4: [272, 544, 1088],
                8: [384, 768, 1536],
            }
            if group not in v1_channels:
                raise ValueError(
                    "ShuffleNet V1 supports group in {}; got {}".format(
                        sorted(v1_channels), group))
            mul = group * 4  # #chan has to be a multiple of this number
            channels = [int(math.ceil(x * args.ratio / mul) * mul)
                        for x in v1_channels[group]]
            # The first channel must be a multiple of group
            first_chan = int(math.ceil(24 * args.ratio / group) * group)
            conv5_chan = None  # V1 has no conv5 layer
        else:
            # Stage 2/3/4 and conv5 output channels, keyed by width multiplier
            # (ShuffleNet V2 paper, Table 5). Only the 2x model widens conv5.
            v2_channels = {
                0.5: ([48, 96, 192], 1024),
                1.: ([116, 232, 464], 1024),
                1.5: ([176, 352, 704], 1024),
                2.: ([244, 488, 976], 2048),
            }
            if args.ratio not in v2_channels:
                raise ValueError(
                    "ShuffleNet V2 supports ratio in {}; got {}".format(
                        sorted(v2_channels), args.ratio))
            channels, conv5_chan = v2_channels[args.ratio]
            channels = list(channels)  # copy so the table entry is never shared
            first_chan = 24

        logger.info("#Channels: " + str([first_chan] + channels))

        l = Conv2D('conv1', image, first_chan, 3, strides=2, activation=BNReLU)
        l = MaxPooling('pool1', l, 3, 2, padding='SAME')

        # Stages repeat their unit 4, 8 and 4 times respectively.
        l = shufflenet_stage('stage2', l, channels[0], 4, group)
        l = shufflenet_stage('stage3', l, channels[1], 8, group)
        l = shufflenet_stage('stage4', l, channels[2], 4, group)

        if args.v2:
            # V2 adds a final 1x1 conv before global pooling.
            l = Conv2D('conv5', l, conv5_chan, 1, activation=BNReLU)

        l = GlobalAvgPooling('gap', l)
        logits = FullyConnected('linear', l, 1000)
        return logits
| 155 | |
| 156 | |
| 157 | def get_data(name, batch): |
nothing calls this directly
no test coverage detected