(graph_name, cli_id, part_id, server_count)
| 91 | |
| 92 | |
def run_client(graph_name, cli_id, part_id, server_count):
    """Check one distributed SparseAdam step against torch's SparseAdam.

    The client connects to the distributed runtime, builds two
    ``DistEmbedding`` tables and two matching ``th.nn.Embedding`` tables,
    runs a single forward/backward/step on the first table of each pair
    with the same indices and labels, and asserts that the first half of
    the rows of the two first tables agree.

    Parameters
    ----------
    graph_name : str
        Name used to locate ``/tmp/dist_graph/<graph_name>.json``.
    cli_id : int
        Not read anywhere in this function body.
    part_id : int
        Partition id forwarded to ``load_partition_book``.
    server_count : int
        Written (as a string) to the ``DGL_NUM_SERVER`` environment variable.

    Raises
    ------
    AssertionError
        If the compared embedding rows differ after the optimizer step.
    """
    ctx = F.ctx()
    # NOTE(review): presumably gives the servers time to come up — confirm.
    time.sleep(5)
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("optim_ip_config.txt")

    book_path = "/tmp/dist_graph/{}.json".format(graph_name)
    part_book, dist_name, _, _ = load_partition_book(book_path, part_id)
    dist_graph = DistGraph(dist_name, gpb=part_book)
    node_policy = dgl.distributed.PartitionPolicy(
        "node", dist_graph.get_partition_book()
    )
    n_nodes = dist_graph.num_nodes()
    dim = 4

    # Both distributed tables share the initializer and partition policy.
    shared_kwargs = {"init_func": initializer, "part_policy": node_policy}
    dist_emb = DistEmbedding(n_nodes, dim, name="optim", **shared_kwargs)
    # This table is handed to the optimizer but never used in a forward pass
    # below, so it receives no gradient.
    dist_emb_idle = DistEmbedding(
        n_nodes, dim, name="optim-zero", **shared_kwargs
    )
    dist_adam = SparseAdam(params=[dist_emb, dist_emb_idle], lr=0.01)
    # NOTE(review): overrides private optimizer fields; looks like it forces
    # single-trainer behaviour — confirm against the SparseAdam implementation.
    dist_adam._world_size = 1
    dist_adam._rank = 0

    ref_emb = th.nn.Embedding(n_nodes, dim, sparse=True)
    ref_emb_idle = th.nn.Embedding(n_nodes, dim, sparse=True)
    # Re-seed before each init so both reference tables get the same weights.
    for table in (ref_emb, ref_emb_idle):
        th.manual_seed(0)
        th.nn.init.uniform_(table.weight, 0, 1.0)
    ref_params = [*ref_emb.parameters(), *ref_emb_idle.parameters()]
    ref_adam = th.optim.SparseAdam(ref_params, lr=0.01)

    targets = th.ones((4,)).long()
    node_ids = th.randint(0, n_nodes, size=(4,))
    dist_out = dist_emb(node_ids, ctx).to(th.device("cpu"))
    ref_out = ref_emb(node_ids)

    # Reference update with torch's optimizer.
    ref_adam.zero_grad()
    ref_loss = th.nn.functional.cross_entropy(ref_out, targets)
    ref_loss.backward()
    ref_adam.step()

    # Same update through the distributed optimizer.
    dist_adam.zero_grad()
    dist_loss = th.nn.functional.cross_entropy(dist_out, targets)
    dist_loss.backward()
    dist_adam.step()

    # Only the first half of the rows is compared.
    half = n_nodes // 2
    assert F.allclose(dist_emb.weight[0:half], ref_emb.weight[0:half])
nothing calls this directly
no test coverage detected