hub / github.com/d2l-ai/d2l-zh / train_epoch_ch8

Function train_epoch_ch8

d2l/mxnet.py:697–720

Trains a model for one epoch (defined in Chapter 8; see :numref:`sec_rnn_scratch`). Returns the perplexity over the epoch and the training throughput in tokens per second.

(net, train_iter, loss, updater, device, use_random_iter)
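The `use_random_iter` flag controls whether the hidden state is carried from one minibatch to the next. The sketch below is a simplified illustration, not the library's code: `count_state_resets` is a hypothetical helper, and a plain integer stands in for the real state object. It mirrors only the `state is None or use_random_iter` branch of the function.

```python
def count_state_resets(num_batches, use_random_iter):
    """Count how often the hidden state would be re-initialized in one epoch.

    Hypothetical helper for illustration only: an integer stands in for the
    real recurrent state, and no model is involved.
    """
    state = None
    resets = 0
    for _ in range(num_batches):
        if state is None or use_random_iter:
            # Stand-in for net.begin_state(...): start from a fresh state.
            state = 0
            resets += 1
        # Otherwise the previous state is reused (the real code detaches it
        # from the computation graph first).
        state += 1  # stand-in for the forward pass updating the state
    return resets


# Sequential partitioning: the state is initialized once and then carried over.
print(count_state_resets(5, use_random_iter=False))
# Random sampling: every minibatch starts from a fresh state.
print(count_state_resets(5, use_random_iter=True))
```

With sequential partitioning, adjacent minibatches are contiguous in the text, so reusing the state is meaningful. With random sampling there is no such continuity, which is presumably why the state is rebuilt each time.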

Source from the content-addressed store, hash-verified

            param.grad[:] *= theta / norm

def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):
    """Train a model within one epoch (defined in Chapter 8).

    Defined in :numref:`sec_rnn_scratch`"""
    state, timer = None, d2l.Timer()
    metric = d2l.Accumulator(2)  # Sum of training loss, no. of tokens
    for X, Y in train_iter:
        if state is None or use_random_iter:
            # Initialize `state` when either it is the first iteration or
            # using random sampling
            state = net.begin_state(batch_size=X.shape[0], ctx=device)
        else:
            for s in state:
                s.detach()
        y = Y.T.reshape(-1)
        X, y = X.as_in_ctx(device), y.as_in_ctx(device)
        with autograd.record():
            y_hat, state = net(X, state)
            l = loss(y_hat, y).mean()
        l.backward()
        grad_clipping(net, 1)
        updater(batch_size=1)  # Since the `mean` function has been invoked
        metric.add(l * d2l.size(y), d2l.size(y))
    return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()

def train_ch8(net, train_iter, vocab, lr, num_epochs, device,
              use_random_iter=False):
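The first return value is a perplexity: the exponential of the token-weighted average loss accumulated in `metric`. A minimal standalone sketch of that arithmetic follows. `perplexity_from_batches` is a hypothetical name, and plain floats replace the MXNet arrays used in the source.

```python
import math


def perplexity_from_batches(batches):
    """Compute perplexity from (mean_loss, num_tokens) pairs, one per minibatch.

    Mirrors metric.add(l * size, size) followed by exp(metric[0] / metric[1]).
    """
    total_loss = 0.0
    total_tokens = 0
    for mean_loss, num_tokens in batches:
        # Undo the per-batch mean so batches of different sizes are weighted
        # by their number of tokens.
        total_loss += mean_loss * num_tokens
        total_tokens += num_tokens
    return math.exp(total_loss / total_tokens)


# A mean cross-entropy of log(4) corresponds to guessing uniformly among
# 4 tokens, so the perplexity should come out close to 4.
uniform_loss = math.log(4)
print(perplexity_from_batches([(uniform_loss, 10), (uniform_loss, 30)]))
```

Multiplying each batch's mean loss by its token count before summing keeps the final average correct even when minibatches differ in size.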

Callers 1

train_ch8 · Function · 0.70

Calls 4

add · Method · 0.95
grad_clipping · Function · 0.70
begin_state · Method · 0.45
stop · Method · 0.45
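Of the callees, `grad_clipping(net, 1)` is the one that changes the gradients before the update. Only its last line, `param.grad[:] *= theta / norm`, is visible in the source excerpt, so the sketch below is an assumption about the rest. It shows the usual clip-by-global-norm scheme on plain Python lists, with `clip_by_global_norm` as a hypothetical name.

```python
import math


def clip_by_global_norm(grads, theta):
    """Rescale gradients so that their combined L2 norm is at most theta.

    grads is a list of lists of floats, one inner list per parameter.
    Assumption: the norm is taken over all parameters jointly and rescaling
    happens only when it exceeds theta.
    """
    norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if norm > theta:
        scale = theta / norm
        return [[g * scale for g in grad] for grad in grads]
    # Already within the bound: return copies unchanged.
    return [list(grad) for grad in grads]


# Two one-element "parameters" with a joint norm of 5, clipped to norm 1.
clipped = clip_by_global_norm([[3.0], [4.0]], theta=1.0)
print(clipped)
print(math.sqrt(sum(g * g for grad in clipped for g in grad)))
```

Rescaling every gradient by the same factor preserves the update direction and bounds only its length, which helps against the exploding gradients that recurrent networks are prone to.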

Tested by

no test coverage detected