Generates a batch iterator for a dataset.
(data, batch_size, num_epochs, shuffle=True)
| 44 | |
| 45 | |
def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """Generate mini-batches over ``data`` for a number of epochs.

    Args:
        data: Sequence of examples. It is converted with ``np.array`` so it
            can be indexed by a permutation array.
        batch_size: Maximum number of examples per batch; must be >= 1.
        num_epochs: Number of passes over the data.
        shuffle: If true, draw a fresh random permutation (from NumPy's
            global RNG) at the start of every epoch.

    Yields:
        Slices of the converted array holding at most ``batch_size``
        examples each. The last batch of an epoch may be smaller. Empty
        ``data`` yields nothing.

    Raises:
        ValueError: If ``batch_size`` is less than 1. Because this is a
            generator, the error surfaces on the first ``next()`` call.
    """
    if batch_size < 1:
        raise ValueError(
            "batch_size must be >= 1, got {!r}".format(batch_size))
    data = np.array(data)
    data_size = len(data)
    # Integer ceiling division. The previous float formula
    # int((n - 1) / batch_size) + 1 returned 1 for n == 0, which yielded an
    # empty batch each epoch. Float division could also round incorrectly
    # for very large n.
    num_batches_per_epoch = -(-data_size // batch_size)
    for _ in range(num_epochs):
        if shuffle:
            # Reshuffle every epoch so batch composition differs between passes.
            shuffled_data = data[np.random.permutation(data_size)]
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min(start_index + batch_size, data_size)
            yield shuffled_data[start_index:end_index]
nothing calls this directly
no outgoing calls
no test coverage detected