| 83 | |
| 84 | |
def build_transform(is_train, args):
    """Build the image transform pipeline for training or evaluation.

    Args:
        is_train: When true, build the training augmentation pipeline via
            ``create_transform``; otherwise build the deterministic
            evaluation pipeline.
        args: Namespace-like config. ``input_size`` is always read; the
            training path additionally reads ``color_jitter``, ``aa``,
            ``train_interpolation``, ``reprob``, ``remode`` and ``recount``.

    Returns:
        For training, the object returned by ``create_transform`` (with its
        first stage replaced by a padded ``RandomCrop`` when
        ``input_size <= 32``). For evaluation, a ``transforms.Compose`` of
        optional resize/center-crop followed by ``ToTensor`` and ImageNet
        mean/std normalization.
    """
    # Inputs of 32 px or smaller are used at native resolution: no resize on
    # eval, and a padded RandomCrop instead of a resized crop on train.
    resize_im = args.input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop (padding=4 so the crop can shift the image)
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        # Resize the shorter side so that a center crop of input_size keeps
        # the same 224/256 crop ratio used for 224-px images.
        size = int((256 / 224) * args.input_size)
        # Prefer the InterpolationMode enum: newer torchvision deprecates
        # integer interpolation codes. Fall back to the legacy PIL code
        # (3 == bicubic) on torchvision versions that predate the enum.
        interp_mode = getattr(transforms, "InterpolationMode", None)
        bicubic = interp_mode.BICUBIC if interp_mode is not None else 3
        t.append(transforms.Resize(size, interpolation=bicubic))
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)