| 22 | import mpu |
| 23 | |
class DataConfig:
    """Holds default values for data-related arguments.

    Defaults are stored by name in ``self.defaults``. Before the data loaders
    are built, they are copied onto an ``args`` namespace, but only for
    attributes it does not already have.
    """

    def __init__(self, defaults=None):
        """Create a config.

        Args:
            defaults: optional mapping of argument name -> default value.
                Names may use '-' or '_'; hyphens are converted to
                underscores when the defaults are applied. A mapping passed
                here is stored as given (not copied), so ``set_defaults``
                writes into it.
        """
        super(DataConfig, self).__init__()
        # Use a fresh dict per instance. A ``{}`` default would be created
        # once at definition time and then shared by every DataConfig built
        # without an explicit mapping; set_defaults mutates it.
        self.defaults = {} if defaults is None else defaults

    def apply(self, args):
        """Fill missing defaults on ``args`` and build the data loaders.

        Args:
            args: attribute namespace that is updated in place
                (presumably an argparse.Namespace; confirm against callers).

        Returns:
            Whatever ``make_loaders(args)`` returns (defined elsewhere in
            this module).
        """
        # Only the rank-0 process prints, so the message appears once.
        if torch.distributed.get_rank() == 0:
            print('configuring data')
        self.apply_defaults(args)
        return make_loaders(args)

    def set_defaults(self, **kwargs):
        """Add or overwrite default values, one per keyword argument."""
        self.defaults.update(kwargs)

    def apply_defaults(self, args):
        """Set each stored default on ``args`` unless it is already present.

        Hyphens in default names are replaced with underscores, so a
        default named ``'seq-length'`` becomes ``args.seq_length``.
        Attributes that ``args`` already has are never overwritten.
        """
        for name, value in self.defaults.items():
            attr = name.replace('-', '_')
            if not hasattr(args, attr):
                setattr(args, attr, value)
| 45 | |
| 46 | |
| 47 | def make_data_loader(dataset, batch_size, args): |