Repository: github.com/d2l-ai/d2l-zh

Function train_ch11

Source: d2l/mxnet.py, lines 1316–1342

Defined in :numref:`sec_minibatches`

Signature: train_ch11(trainer_fn, states, hyperparams, data_iter, feature_dim, num_epochs=2)
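The first three parameters form a small plug-in interface. After each backward pass, train_ch11 calls trainer_fn([w, b], states, hyperparams), so a trainer function is expected to update the parameters in place from their gradients, keep any optimizer memory in states, and read its settings from hyperparams (assumed here to be a dict). The following is a minimal NumPy sketch of a stateful trainer (heavy-ball momentum), written under those assumptions. Param, init_momentum_states, and the 'lr' and 'momentum' keys are hypothetical names for illustration. Param only stands in for an MXNet array with an attached gradient; an MXNet trainer would update the array itself (for example p[:] -= ...) rather than a .data field.

```python
import numpy as np


class Param:
    # Hypothetical stand-in for an MXNet array after attach_grad():
    # `data` holds the value, `grad` the gradient filled in by backward().
    def __init__(self, data):
        self.data = np.asarray(data, dtype=float)
        self.grad = np.zeros_like(self.data)


def init_momentum_states(feature_dim):
    # One velocity buffer per parameter, shaped like w (feature_dim, 1) and b (1,).
    return [np.zeros((feature_dim, 1)), np.zeros(1)]


def sgd_momentum(params, states, hyperparams):
    # Heavy-ball momentum: v <- momentum * v + grad, then p <- p - lr * v.
    # Both updates happen in place, so `states` carries velocity across calls.
    for p, v in zip(params, states):
        v[:] = hyperparams['momentum'] * v + p.grad
        p.data -= hyperparams['lr'] * v
```

Passing states separately keeps train_ch11 optimizer-agnostic: a plain SGD trainer can ignore states, while a stateful one such as this keeps its buffers there between calls.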

```python
# Imports needed to run this excerpt on its own (inside d2l/mxnet.py these
# names are already available at module level).
from mxnet import autograd, np, npx
from d2l import mxnet as d2l

npx.set_np()


def train_ch11(trainer_fn, states, hyperparams, data_iter,
               feature_dim, num_epochs=2):
    """Defined in :numref:`sec_minibatches`"""
    # Initialization: small random weights, zero bias, gradients attached
    w = np.random.normal(scale=0.01, size=(feature_dim, 1))
    b = np.zeros(1)
    w.attach_grad()
    b.attach_grad()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    # Train
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[0, num_epochs], ylim=[0.22, 0.35])
    n, timer = 0, d2l.Timer()
    for _ in range(num_epochs):
        for X, y in data_iter:
            with autograd.record():
                l = loss(net(X), y).mean()
            l.backward()
            # The trainer updates w and b in place from their gradients
            trainer_fn([w, b], states, hyperparams)
            n += X.shape[0]
            if n % 200 == 0:
                # Pause the timer so evaluation time is not counted
                timer.stop()
                # x-axis in epochs: batches seen / batches per epoch
                # (assumes every minibatch has the same size)
                animator.add(n/X.shape[0]/len(data_iter),
                             (d2l.evaluate_loss(net, data_iter, loss),))
                timer.start()
    # timer.avg() averages the intervals between evaluations (every 200
    # examples), so the "sec/epoch" label is per interval, not per epoch
    print(f'loss: {animator.Y[0][-1]:.3f}, {timer.avg():.3f} sec/epoch')
    return timer.cumsum(), animator.Y[0]
```
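The control flow of the loop above can be checked without MXNet. The following is a minimal NumPy sketch under stated assumptions: autograd is replaced by the hand-derived gradient of the mean squared loss (y_hat - y)^2 / 2 for a linear model, the Animator and Timer are dropped, and the function shuffles the data itself instead of consuming a data_iter. The names train_linreg_numpy, sgd, eval_every, and the SimpleNamespace parameter objects are hypothetical; only the trainer_fn(params, states, hyperparams) calling convention, the evaluation every 200 examples, and the epoch x-axis are taken from train_ch11.

```python
import numpy as np
from types import SimpleNamespace


def sgd(params, states, hyperparams):
    # Stateless minibatch SGD; `states` is accepted only to match the interface.
    for p in params:
        p.data -= hyperparams['lr'] * p.grad


def train_linreg_numpy(trainer_fn, states, hyperparams, features, labels,
                       batch_size, num_epochs=2, eval_every=200, seed=0):
    """Hypothetical NumPy analogue of train_ch11 (no MXNet, no plotting).

    `labels` has shape (num_examples, 1). Returns (epoch, loss) pairs, where
    loss is the mean squared loss over the full dataset, recorded whenever
    the number of examples seen is a multiple of `eval_every`.
    """
    rng = np.random.default_rng(seed)
    num_examples, feature_dim = features.shape
    w = SimpleNamespace(data=rng.normal(scale=0.01, size=(feature_dim, 1)),
                        grad=None)
    b = SimpleNamespace(data=np.zeros(1), grad=None)
    num_batches = -(-num_examples // batch_size)  # ceiling division

    def full_loss():
        y_hat = features @ w.data + b.data
        return float(np.mean((y_hat - labels) ** 2 / 2))

    history, n = [], 0
    for _ in range(num_epochs):
        order = rng.permutation(num_examples)
        for start in range(0, num_examples, batch_size):
            idx = order[start:start + batch_size]
            X, y = features[idx], labels[idx]
            err = X @ w.data + b.data - y  # shape (batch, 1)
            # Hand-derived gradients of mean((X w + b - y)^2 / 2)
            w.grad = X.T @ err / len(idx)
            b.grad = err.mean(axis=0)
            trainer_fn([w, b], states, hyperparams)
            n += len(idx)
            if n % eval_every == 0:
                # Same x-axis as train_ch11: batches seen / batches per epoch
                history.append((n / len(idx) / num_batches, full_loss()))
    return history


# Usage on synthetic data (made-up weights, bias 0.2, small label noise)
data_rng = np.random.default_rng(1)
features = data_rng.normal(size=(1500, 5))
true_w = np.array([[2.0], [-1.0], [0.5], [0.0], [3.0]])
labels = features @ true_w + 0.2 + 0.01 * data_rng.normal(size=(1500, 1))
history = train_linreg_numpy(sgd, None, {'lr': 0.05}, features, labels,
                             batch_size=10)
print(history[0], history[-1])
```

With 1,500 examples and batch_size=10, an epoch is 150 minibatches, so the first checkpoint (n = 200) sits at 200 / 10 / 150 ≈ 0.133 epochs, and two epochs produce 15 checkpoints ending at epoch 2.0.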

Callers

nothing calls this directly

Calls (5)

add (method: animator.add) · 0.95
stop (method: timer.stop) · 0.45
start (method: timer.start) · 0.45
avg (method: timer.avg) · 0.45
cumsum (method: timer.cumsum) · 0.45
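Four of these calls go to the timer. train_ch11 stops it just before each d2l.evaluate_loss call and restarts it right after, so only training steps are timed; the returned timer.cumsum() therefore lists cumulative training seconds at each evaluation checkpoint, one entry per value in animator.Y[0]. The following is a minimal stand-in with the start/stop/avg/cumsum interface (plus sum), modelled on the assumed behaviour of d2l.Timer rather than copied from it; it uses time.perf_counter.

```python
import time


class Timer:
    """Sketch of a stopwatch that records the intervals between start() and stop()."""

    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        self.tik = time.perf_counter()

    def stop(self):
        # Record the interval since the last start() and return it.
        self.times.append(time.perf_counter() - self.tik)
        return self.times[-1]

    def avg(self):
        return sum(self.times) / len(self.times)

    def sum(self):
        return sum(self.times)

    def cumsum(self):
        # Running totals: timed seconds accumulated up to each stop().
        totals, acc = [], 0.0
        for t in self.times:
            acc += t
            totals.append(acc)
        return totals


# Usage: only the work between start() and stop() is counted
timer = Timer()
time.sleep(0.01)   # timed "training" work
timer.stop()
time.sleep(0.02)   # untimed "evaluation" work
timer.start()
time.sleep(0.01)
timer.stop()
print(len(timer.times), timer.cumsum())
```

Returning both lists lets callers plot loss against training time as well as against epochs.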

Tested by

no test coverage detected