MCPcopy
hub / github.com/d2l-ai/d2l-zh / train_ranking

Function train_ranking

d2l/mxnet.py:2709–2740  ·  view source on GitHub ↗
(net, train_iter, test_iter, loss, trainer, test_seq_iter,
                  num_users, num_items, num_epochs, devices, evaluator,
                  candidates, eval_step=1)

Source from the content-addressed store, hash-verified

2707 return np.mean(np.array(hit_rate)), np.mean(np.array(auc))
2708
def train_ranking(net, train_iter, test_iter, loss, trainer, test_seq_iter,
                  num_users, num_items, num_epochs, devices, evaluator,
                  candidates, eval_step=1):
    """Train a pairwise ranking model and report test hit rate and AUC.

    For every batch, each field is split across ``devices``; the network is
    run once on all fields except the last (``p_pos``) and once with the
    second-to-last field replaced by the last one (``p_neg``), and
    ``loss(p_pos, p_neg)`` is back-propagated on every device shard.
    Presumably the last two fields are the positive and negative items --
    confirm against the data iterator.

    Args:
        net: Callable model, invoked as ``net(*fields)`` per device shard.
        train_iter: Iterable of batches; each batch is a sequence of arrays.
        test_iter: Passed through unchanged to ``evaluator``.
        loss: Callable ``loss(p_pos, p_neg)`` returning an array that
            supports ``backward`` and ``asnumpy``.
        trainer: Optimizer wrapper; ``trainer.step(batch_size)`` is called
            once per batch.
        test_seq_iter: Passed through unchanged to ``evaluator``.
        num_users: Passed through unchanged to ``evaluator``.
        num_items: Passed through unchanged to ``evaluator``.
        num_epochs: Number of passes over ``train_iter``.
        devices: Devices the batch is split across.
        evaluator: Callable returning ``(hit_rate, auc)`` when invoked as
            ``evaluator(net, test_iter, test_seq_iter, candidates,
            num_users, num_items, devices)``.
        candidates: Passed through unchanged to ``evaluator``.
        eval_step: Evaluate (and update the plot) every ``eval_step`` epochs.

    Returns:
        None. The final metrics and throughput are printed.
    """
    timer, hit_rate, auc = d2l.Timer(), 0, 0
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['test hit rate', 'test AUC'])
    for epoch in range(num_epochs):
        # Slots: [sum of per-batch losses, number of examples, element count].
        metric = d2l.Accumulator(3)
        for values in train_iter:
            # Restart the timer for every batch so that each stop() records
            # only this batch's duration (previously every interval was
            # measured from timer construction, evaluation time included).
            timer.start()
            input_data = [gluon.utils.split_and_load(v, devices)
                          for v in values]
            with autograd.record():
                p_pos = [net(*t) for t in zip(*input_data[0:-1])]
                p_neg = [net(*t) for t in zip(*input_data[0:-2],
                                              input_data[-1])]
                ls = [loss(p, n) for p, n in zip(p_pos, p_neg)]
            for shard_loss in ls:
                shard_loss.backward(retain_graph=False)
            # Loss of this batch only. The accumulator already keeps the
            # running total, so adding a running sum here (as before) would
            # count earlier batches repeatedly.
            batch_loss = sum(
                [shard_loss.asnumpy() for shard_loss in ls]
            ).mean() / len(devices)
            batch_size = values[0].shape[0]
            trainer.step(batch_size)
            # NOTE(review): batch_loss is already a mean, yet the final
            # report divides by the example count -- normalization kept
            # as in the original; confirm the intended scale.
            metric.add(batch_loss, batch_size, values[0].size)
            timer.stop()
        with autograd.predict_mode():
            if (epoch + 1) % eval_step == 0:
                hit_rate, auc = evaluator(net, test_iter, test_seq_iter,
                                          candidates, num_users, num_items,
                                          devices)
                animator.add(epoch + 1, (hit_rate, auc))
    # `metric` holds the last epoch only, hence the scaling by num_epochs
    # when estimating overall throughput.
    print(f'train loss {metric[0] / metric[1]:.3f}, '
          f'test hit rate {float(hit_rate):.3f}, test AUC {float(auc):.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(devices)}')
2741
# Register the 'ctr' dataset in the download hub as a (URL, hex digest) pair.
# The digest is 40 hex characters -- presumably SHA-1; confirm against the
# download helper that consumes DATA_HUB.
d2l.DATA_HUB['ctr'] = (
    d2l.DATA_URL + 'ctr.zip',
    'e18327c48c8e8e5c23da714dd614e390d369843f',
)

Callers

nothing calls this directly

Calls 3

addMethod · 0.95
stopMethod · 0.45
sumMethod · 0.45

Tested by

no test coverage detected