Repository: github.com/dmlc/dgl

Function test_DeepWalk

tests/python/pytorch/nn/test_nn.py:2421–2448
Signature: () (no parameters)


def test_DeepWalk():
    dev = F.ctx()
    # Directed 3-node graph: every node has an edge to each of the other two.
    g = dgl.graph(([0, 1, 2, 1, 2, 0], [1, 2, 0, 0, 1, 2]))

    # Variant 1: fast negative sampling with sparse embedding gradients.
    model = nn.DeepWalk(
        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=True, sparse=True
    )
    model = model.to(dev)
    # model.sample is the collate_fn, so each batch of start-node IDs
    # is turned into a tensor of random walks.
    dataloader = DataLoader(
        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
    )
    # Sparse gradients need SparseAdam; plain Adam rejects them.
    optim = SparseAdam(model.parameters(), lr=0.01)
    walk = next(iter(dataloader)).to(dev)
    loss = model(walk)
    loss.backward()
    optim.step()

    # Variant 2: fast_neg disabled, dense embedding gradients, plain Adam.
    model = nn.DeepWalk(
        g, emb_dim=8, walk_length=2, window_size=1, fast_neg=False, sparse=False
    )
    model = model.to(dev)
    dataloader = DataLoader(
        torch.arange(g.num_nodes()), batch_size=16, collate_fn=model.sample
    )
    optim = Adam(model.parameters(), lr=0.01)
    walk = next(iter(dataloader)).to(dev)
    loss = model(walk)
    loss.backward()
    optim.step()
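The test hands `model.sample` to `DataLoader` as its `collate_fn`, so every batch of start-node IDs becomes random walks before `model(walk)` computes a loss. The stdlib-only sketch below illustrates, under stated assumptions, what that sampling step and the skip-gram context window amount to on the test's 3-node graph. It is not DGL's implementation: `build_adj`, `random_walk`, and `context_pairs` are hypothetical helpers, walks are assumed to be uniform and to contain `walk_length` nodes, and `window_size=1` is taken to mean immediate neighbours within a walk.

```python
import random
from collections import defaultdict

# Same edge list as the test's dgl.graph((src, dst)) call.
SRC = [0, 1, 2, 1, 2, 0]
DST = [1, 2, 0, 0, 1, 2]


def build_adj(src, dst):
    """Hypothetical helper: out-neighbour lists from parallel src/dst lists."""
    adj = defaultdict(list)
    for u, v in zip(src, dst):
        adj[u].append(v)
    return adj


def random_walk(adj, start, walk_length, rng):
    """Uniform random walk; ASSUMPTION: walk_length counts nodes, not steps."""
    walk = [start]
    while len(walk) < walk_length:
        nbrs = adj[walk[-1]]
        if not nbrs:  # dead end: stop the walk early
            break
        walk.append(rng.choice(nbrs))
    return walk


def context_pairs(walk, window_size):
    """Skip-gram (center, context) pairs within window_size positions."""
    pairs = []
    for i, center in enumerate(walk):
        lo = max(0, i - window_size)
        hi = min(len(walk), i + window_size + 1)
        for j in range(lo, hi):
            if j != i:
                pairs.append((center, walk[j]))
    return pairs


rng = random.Random(0)
adj = build_adj(SRC, DST)
# One walk per node, mirroring torch.arange(g.num_nodes()) as the batch.
walks = [random_walk(adj, n, walk_length=2, rng=rng) for n in range(3)]
pairs = [p for w in walks for p in context_pairs(w, window_size=1)]
print(walks)
print(pairs)
```

Under these assumptions, `walk_length=2` makes each walk a single edge and `window_size=1` yields two (center, context) pairs per walk, so the test exercises the training loop on very small inputs rather than checking embedding quality.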
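`loss = model(walk)` returns a scalar that the optimizer then minimizes. DeepWalk is commonly trained with the skip-gram objective plus negative sampling, and the sketch below writes that standard objective out for one (center, context) pair in plain Python. It illustrates the technique, not DGL's exact loss, which may weight, clamp, or batch terms differently. `batch_negatives` is a hypothetical helper reflecting one reading of `fast_neg=True` (drawing negatives from nodes already present in the batch); the `fast_neg=False` path is not modelled here, and the toy embeddings are made up for illustration.

```python
import math
import random


def log_sigmoid(x):
    """Numerically stable log(1 / (1 + exp(-x)))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def sgns_loss(center, context, negatives):
    """Skip-gram negative-sampling loss for one pair:
    -log sigma(c . o) - sum_k log sigma(-c . n_k)."""
    loss = -log_sigmoid(dot(center, context))
    for neg in negatives:
        loss -= log_sigmoid(-dot(center, neg))
    return loss


def batch_negatives(walks, k, rng):
    """Hypothetical in-batch sampler: draw k negatives from nodes in this batch.
    Accidental positives are not filtered out."""
    pool = [node for walk in walks for node in walk]
    return [rng.choice(pool) for _ in range(k)]


# Toy 2-d embeddings, illustrative only.
emb = {0: [0.1, -0.2], 1: [0.3, 0.0], 2: [-0.1, 0.4]}
rng = random.Random(0)
walks = [[0, 1], [1, 2], [2, 0]]
negs = batch_negatives(walks, k=2, rng=rng)
print(sgns_loss(emb[0], emb[1], [emb[n] for n in negs]))
```

Minimizing this loss pulls a node's embedding toward nodes that co-occur with it in walks and pushes it away from the sampled negatives.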

Callers

nothing calls this directly

Calls (9)

DataLoader (class) · 0.90
SparseAdam (class) · 0.90
parameters (method) · 0.80
ctx (method) · 0.45
graph (method) · 0.45
to (method) · 0.45
num_nodes (method) · 0.45
backward (method) · 0.45
step (method) · 0.45

Tested by

no test coverage detected