(is_train, config)
| 155 | |
| 156 | |
def build_transform(is_train, config):
    """Build the image preprocessing pipeline for training or evaluation.

    Args:
        is_train: Truthy to build the training augmentation pipeline, falsy
            to build the deterministic evaluation pipeline.
        config: Configuration object. Reads ``DATA.IMG_SIZE``,
            ``DATA.MEAN_AND_STD_TYPE`` and ``DATA.INTERPOLATION``; the
            training path also reads ``DISTILL.ENABLED`` and the ``AUG.*``
            options, the evaluation path also reads ``TEST.CROP``.

    Returns:
        For training, whatever the selected ``create_transform`` /
        ``create_transform_record`` factory returns (its first transform is
        swapped for a ``RandomCrop`` when the image size is 32 or less).
        For evaluation, a ``transforms.Compose`` ending in ``ToTensor`` and
        ``Normalize``.

    Raises:
        KeyError: If ``config.DATA.MEAN_AND_STD_TYPE`` is not one of
            ``'default'``, ``'inception'`` or ``'clip'``.
    """
    img_size = config.DATA.IMG_SIZE
    # Images of 32 px or less are never resized: training falls back to a
    # padded random crop and evaluation feeds the image through unchanged.
    needs_resize = img_size > 32

    # RGB (mean, std) pairs keyed by normalization scheme.
    norm_stats = {
        'default': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        'inception': (IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD),
        'clip': (
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        ),
    }
    mean, std = norm_stats[config.DATA.MEAN_AND_STD_TYPE]

    if not is_train:
        pipeline = []
        if needs_resize:
            if config.TEST.CROP:
                # Upscale by 256/224 before the center crop to maintain the
                # same ratio w.r.t. 224 images.
                scaled_size = int((256 / 224) * img_size)
                pipeline.append(transforms.Resize(
                    scaled_size,
                    interpolation=_pil_interp(config.DATA.INTERPOLATION),
                ))
                pipeline.append(transforms.CenterCrop(img_size))
            else:
                # No crop: resize straight to a square of the target size.
                pipeline.append(transforms.Resize(
                    (img_size, img_size),
                    interpolation=_pil_interp(config.DATA.INTERPOLATION),
                ))
        pipeline.extend([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        return transforms.Compose(pipeline)

    # this should always dispatch to transforms_imagenet_train
    if config.DISTILL.ENABLED:
        factory = create_transform_record
    else:
        factory = create_transform

    jitter = config.AUG.COLOR_JITTER
    policy = config.AUG.AUTO_AUGMENT
    train_options = {
        'input_size': img_size,
        'is_training': True,
        # A non-positive jitter strength or the literal 'none' policy
        # disables the corresponding augmentation.
        'color_jitter': jitter if jitter > 0 else None,
        'auto_augment': policy if policy != 'none' else None,
        're_prob': config.AUG.REPROB,
        're_mode': config.AUG.REMODE,
        're_count': config.AUG.RECOUNT,
        'interpolation': config.DATA.INTERPOLATION,
        'mean': mean,
        'std': std,
    }
    train_transform = factory(**train_options)

    if not needs_resize:
        # replace RandomResizedCropAndInterpolation with RandomCrop
        train_transform.transforms[0] = transforms.RandomCrop(
            img_size, padding=4)

    return train_transform
no test coverage detected