函数作用:将数据集分割成 batch, 基于 mini batch 训练。
(X, batchsize=256, shuffle=True)
| 257 | |
| 258 | |
| 259 | def minibatch(X, batchsize=256, shuffle=True): |
| 260 | """ |
| 261 | 函数作用:将数据集分割成 batch, 基于 mini batch 训练。 |
| 262 | """ |
| 263 | N = X.shape[0] |
| 264 | idx = np.arange(N) |
| 265 | n_batches = int(np.ceil(N / batchsize)) |
| 266 | |
| 267 | if shuffle: |
| 268 | np.random.shuffle(idx) |
| 269 | |
| 270 | def mb_generator(): |
| 271 | for i in range(n_batches): |
| 272 | yield idx[i * batchsize : (i + 1) * batchsize] |
| 273 | |
| 274 | return mb_generator(), n_batches |
| 275 | |
| 276 | |
| 277 | class DFN(object): |