(args = None)
| 86 | |
| 87 | |
def new_data_aug_generator(args=None, *, remove_random_resized_crop=False):
    """Build the training-time augmentation pipeline.

    The returned pipeline composes three stages in order:

    1. Primary (geometric): either ``RandomResizedCropAndInterpolation`` with
       scale ``(0.08, 1.0)`` and bicubic interpolation, or -- when
       ``remove_random_resized_crop`` is true -- a resize followed by a
       reflect-padded random crop. Both variants end with a random
       horizontal flip.
    2. Secondary (photometric): ``transforms.RandomChoice`` over grayscale,
       solarization and Gaussian blur (each constructed with ``p=1.0``),
       optionally followed by ``ColorJitter``.
    3. Final: ``ToTensor`` then ``Normalize`` with the standard ImageNet
       per-channel mean/std values.

    Args:
        args: Options namespace. Must provide ``input_size`` (crop size passed
            to the crop/resize transforms) and ``color_jitter`` (jitter
            strength; ``None`` or ``0`` disables jitter). Required in
            practice; the ``None`` default is kept only for signature
            compatibility.
        remove_random_resized_crop: If true, use resize + random crop instead
            of the random resized crop. Defaults to ``False``, which matches
            the previously hard-coded behavior.

    Returns:
        A ``transforms.Compose`` applying the three stages above in order.

    Raises:
        TypeError: If ``args`` is ``None``.
    """
    if args is None:
        raise TypeError("new_data_aug_generator() requires an 'args' namespace")

    img_size = args.input_size
    # Standard ImageNet channel statistics (RGB order).
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    if remove_random_resized_crop:
        primary_tfl = [
            # interpolation=3 is PIL.Image.BICUBIC, matching the bicubic
            # interpolation used by the other branch.
            transforms.Resize(img_size, interpolation=3),
            transforms.RandomCrop(img_size, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        primary_tfl = [
            RandomResizedCropAndInterpolation(
                img_size, scale=(0.08, 1.0), interpolation='bicubic'),
            transforms.RandomHorizontalFlip(),
        ]

    # RandomChoice picks exactly one of these per call. Each is built with
    # p=1.0, presumably so that the chosen transform always applies
    # (gray_scale/Solarization/GaussianBlur are defined elsewhere -- confirm).
    secondary_tfl = [
        transforms.RandomChoice([
            gray_scale(p=1.0),
            Solarization(p=1.0),
            GaussianBlur(p=1.0),
        ])
    ]

    color_jitter = args.color_jitter
    if color_jitter is not None and color_jitter != 0:
        # Same strength for brightness, contrast and saturation; hue untouched.
        secondary_tfl.append(
            transforms.ColorJitter(color_jitter, color_jitter, color_jitter))

    final_tfl = [
        transforms.ToTensor(),
        transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
    ]
    return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
no test coverage detected