| 53 | |
| 54 | |
class TextSamplerDataset(Dataset):
    """Map-style dataset that yields random contiguous windows of ``data``.

    Each ``__getitem__`` call draws a fresh random start offset (from torch's
    global RNG) and returns ``seq_len + 1`` consecutive elements of ``data``.
    The requested index only selects *whether* an item exists; it does not
    select *which* window is returned. The nominal length is
    ``data.size(0) // seq_len``.

    Args:
        data: tensor sliced along dim 0 (presumably a 1-D tensor of token ids
            -- TODO confirm against caller).
        seq_len: window length; each sample has ``seq_len + 1`` elements
            (presumably so callers can split it into input and a target
            shifted by one -- confirm against the training loop).
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        """Return a random window of ``seq_len + 1`` elements as an int64 tensor.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
                Without this check, plain iteration (``for x in dataset``,
                ``list(dataset)``) would never terminate, because Python's
                sequence-iteration fallback stops only on IndexError.
            RuntimeError: from ``torch.randint`` when
                ``data.size(0) <= seq_len`` (empty sampling range).
        """
        n_items = len(self)
        if not -n_items <= index < n_items:
            raise IndexError(
                f"index {index} out of range for dataset of length {n_items}"
            )
        # randint's upper bound is exclusive, so the largest start is
        # size(0) - seq_len - 1 and the slice end never exceeds size(0).
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)).item()
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq

    def __len__(self):
        """Nominal epoch length: number of non-overlapping ``seq_len`` chunks."""
        return self.data.size(0) // self.seq_len
| 68 | |
| 69 | |
# Training split wrapped so each sample is a random window of SEQ_LEN + 1 elements.
train_dataset = TextSamplerDataset(data=data_train, seq_len=SEQ_LEN)