| 80 | |
| 81 | |
| 82 | class GPT2Dataset(Dataset): |
| 83 | |
| 84 | def __init__(self, data_path, sizes_filename, seq_length, |
| 85 | initial_seed, max_epochs=100): |
| 86 | # Input parameters. |
| 87 | self.data_path = data_path |
| 88 | self.sizes_filename = sizes_filename |
| 89 | self.seq_length = seq_length |
| 90 | self.initial_seed = initial_seed |
| 91 | self.max_epochs = max_epochs |
| 92 | # Lock for building the dataset. |
| 93 | self.lock = Lock() |
| 94 | |
| 95 | # Shard stuff. |
| 96 | # Dictionary from shard nameto its size (number of element). |
| 97 | self.master_shard_size_dict = None |
| 98 | # Dictionary from shard name to modified size so it is |
| 99 | # divisible by self.seq_length. |
| 100 | self.shard_size_dict = None |
| 101 | # Long array (self.max_epochs * num-shards) populated |
| 102 | # randomly with shard names. |
| 103 | self.shards_name = None |
| 104 | # Start index of the data for a shard. |
| 105 | self.shards_start_index = None |
| 106 | self.build_shard_mappings_() |
| 107 | self.data_length = self.shards_start_index[-1] |
| 108 | |
| 109 | # Data. |
| 110 | self.shards_data = [None]*self.shards_name.size |
| 111 | self.shards_sample_index = [None]*self.shards_name.size |
| 112 | |
| 113 | def __len__(self): |
| 114 | return self.data_length |
| 115 | |
| 116 | def __getitem__(self, idx): |
| 117 | # Find which shard we need. |
| 118 | shard_index = np.searchsorted(self.shards_start_index, |
| 119 | idx, side='right') - 1 |
| 120 | # data index in the shard. |
| 121 | data_idx = idx - self.shards_start_index[shard_index] |
| 122 | # Load the shard if it is not in memory. |
| 123 | #self.lock.acquire() |
| 124 | if self.shards_data[shard_index] is None: |
| 125 | print('global rank {} is building data for shard index {} ...'. |
| 126 | format(torch.distributed.get_rank(), shard_index)) |
| 127 | self.build_dataset_(shard_index) |
| 128 | #assert self.shards_data[shard_index] is not None |
| 129 | #self.lock.release() |
| 130 | # Start index. |
| 131 | start_index = self.shards_sample_index[shard_index][data_idx] |
| 132 | # Add one for label shift. |
| 133 | end_index = start_index + self.seq_length + 1 |
| 134 | data = self.shards_data[shard_index][start_index:end_index] |
| 135 | return {'text': np.array(data, dtype=np.int64)} |
| 136 | |
| 137 | def build_dataset_(self, shard_index): |
| 138 | # Garbage collect so we don't use a lot of memory. |
| 139 | # Leave the last one in case other threads have not catche up yet. |