| 55 | |
| 56 | |
| 57 | class PrefetchLoader: |
| 58 | |
def __init__(self,
             loader,
             mean=IMAGENET_DEFAULT_MEAN,
             std=IMAGENET_DEFAULT_STD,
             fp16=False,
             re_prob=0.,
             re_mode='const',
             re_count=1,
             re_num_splits=0):
    """Wrap ``loader`` and prepare the on-GPU normalisation constants.

    Args:
        loader: Iterable yielding ``(input, target)`` pairs; it is stored
            as-is and iterated by ``__iter__``.
        mean: Per-channel mean. Each value is multiplied by 255 before
            being stored (presumably because batches arrive as 0-255
            pixel values -- confirm against the collate function).
        std: Per-channel standard deviation, scaled by 255 like ``mean``.
        fp16: When truthy, the stored mean/std tensors are converted to
            half precision.
        re_prob: Random-erasing probability. Random erasing is only set up
            when this is greater than zero.
        re_mode: Forwarded to ``RandomErasing`` as ``mode``.
        re_count: Forwarded to ``RandomErasing`` as ``max_count``.
        re_num_splits: Forwarded to ``RandomErasing`` as ``num_splits``.
    """
    self.loader = loader
    self.fp16 = fp16

    # Shape (1, 3, 1, 1): the view requires exactly three values and lets
    # the constants broadcast against a 4-D batch along dimension 1.
    scaled_mean = torch.tensor([value * 255 for value in mean]).cuda().view(1, 3, 1, 1)
    scaled_std = torch.tensor([value * 255 for value in std]).cuda().view(1, 3, 1, 1)
    self.mean = scaled_mean.half() if fp16 else scaled_mean
    self.std = scaled_std.half() if fp16 else scaled_std

    # ``None`` marks random erasing as disabled.
    self.random_erasing = RandomErasing(
        probability=re_prob,
        mode=re_mode,
        max_count=re_count,
        num_splits=re_num_splits,
    ) if re_prob > 0. else None
| 80 | |
def __iter__(self):
    """Yield ``(input, target)`` batches, preparing each one a step ahead.

    Each batch from ``self.loader`` is copied to the GPU, cast, normalised
    with ``self.mean`` / ``self.std`` and optionally random-erased inside a
    dedicated CUDA stream. The previously prepared batch is yielded before
    the current stream is made to wait on that side stream, so the
    consumer's work on batch ``i`` can overlap with the preparation of
    batch ``i + 1``.

    An empty ``self.loader`` yields nothing.

    Yields:
        tuple: ``(input, target)`` tensors on the CUDA device. ``input`` is
        half precision when ``self.fp16`` is truthy, otherwise float32.
    """
    stream = torch.cuda.Stream()
    # Batch that is already synchronised with the current stream and is
    # waiting to be handed to the consumer; ``None`` until the first batch
    # has been prepared. Using a sentinel (instead of a ``first`` flag and
    # unconditionally yielding at the end) keeps an empty loader from
    # raising ``UnboundLocalError``.
    pending = None

    for next_input, next_target in self.loader:
        with torch.cuda.stream(stream):
            # Work issued in this block is queued on the side stream.
            next_input = next_input.cuda(non_blocking=True)
            next_target = next_target.cuda(non_blocking=True)
            next_input = next_input.half() if self.fp16 else next_input.float()
            next_input = next_input.sub_(self.mean).div_(self.std)
            if self.random_erasing is not None:
                next_input = self.random_erasing(next_input)

        # Hand out the previous batch first; the consumer runs while the
        # work queued above is still in flight on the side stream.
        if pending is not None:
            yield pending

        # Order the current stream after the side stream before the newly
        # prepared batch becomes the next one to be yielded.
        torch.cuda.current_stream().wait_stream(stream)
        pending = (next_input, next_target)

    # Flush the final batch; skipped entirely when the loader was empty.
    if pending is not None:
        yield pending
| 106 | |
| 107 | def __len__(self): |
| 108 | return len(self.loader) |
| 109 | |
| 110 | @property |
| 111 | def sampler(self): |
| 112 | return self.loader.sampler |
| 113 | |
| 114 | @property |