Repository: github.com/dmlc/dgl

Function run_client

tests/python/pytorch/distributed/optim/test_dist_optim.py:93–150
(graph_name, cli_id, part_id, server_count)

def run_client(graph_name, cli_id, part_id, server_count):
    device = F.ctx()
    time.sleep(5)
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("optim_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id
    )
    g = DistGraph(graph_name, gpb=gpb)
    policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
    num_nodes = g.num_nodes()
    emb_dim = 4
    dgl_emb = DistEmbedding(
        num_nodes,
        emb_dim,
        name="optim",
        init_func=initializer,
        part_policy=policy,
    )
    dgl_emb_zero = DistEmbedding(
        num_nodes,
        emb_dim,
        name="optim-zero",
        init_func=initializer,
        part_policy=policy,
    )
    dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
    dgl_adam._world_size = 1
    dgl_adam._rank = 0

    torch_emb = th.nn.Embedding(num_nodes, emb_dim, sparse=True)
    torch_emb_zero = th.nn.Embedding(num_nodes, emb_dim, sparse=True)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb_zero.weight, 0, 1.0)
    torch_adam = th.optim.SparseAdam(
        list(torch_emb.parameters()) + list(torch_emb_zero.parameters()),
        lr=0.01,
    )

    labels = th.ones((4,)).long()
    idx = th.randint(0, num_nodes, size=(4,))
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)
    torch_adam.zero_grad()
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    torch_loss.backward()
    torch_adam.step()

    dgl_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    dgl_loss.backward()
    dgl_adam.step()

    assert F.allclose(
        dgl_emb.weight[0 : num_nodes // 2], torch_emb.weight[0 : num_nodes // 2]
    )
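The listing calls an `initializer` function that is defined elsewhere in the file, outside lines 93–150. The sketch below is a minimal, PyTorch-only illustration of the reference half of the comparison; it does not start any DGL servers or clients. It assumes a hypothetical `initializer(shape, dtype)` (the callback shape DGL documents for `DistEmbedding`'s `init_func`) that reseeds and fills uniformly in [0, 1), mirroring how `torch_emb` is initialized. The table size of 10 rows and the fixed indices are also assumptions, chosen so the touched rows are known in advance. It shows that one `SparseAdam` step moves only the looked-up rows, while the `*_zero` table, which is registered with the optimizer but never used in the forward pass, receives no gradient and stays unchanged.

```python
import torch as th


def initializer(shape, dtype):
    # Hypothetical stand-in for the test's initializer (not shown in the
    # listing): reseed, then fill uniformly in [0, 1), the same recipe the
    # test applies to torch_emb.
    arr = th.zeros(shape, dtype=dtype)
    th.manual_seed(0)
    th.nn.init.uniform_(arr, 0, 1.0)
    return arr


num_nodes, emb_dim = 10, 4  # num_nodes is assumed; emb_dim matches the test

# Reference embeddings, seeded the same way as in run_client.
torch_emb = th.nn.Embedding(num_nodes, emb_dim, sparse=True)
torch_emb_zero = th.nn.Embedding(num_nodes, emb_dim, sparse=True)
th.manual_seed(0)
th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
th.manual_seed(0)
th.nn.init.uniform_(torch_emb_zero.weight, 0, 1.0)

# Both tables are registered, but only torch_emb takes part in the loss,
# so torch_emb_zero.weight.grad stays None and SparseAdam skips it.
torch_adam = th.optim.SparseAdam(
    list(torch_emb.parameters()) + list(torch_emb_zero.parameters()), lr=0.01
)

before = torch_emb.weight.detach().clone()
before_zero = torch_emb_zero.weight.detach().clone()

# Fixed indices (assumption) instead of th.randint; row 3 is looked up twice.
idx = th.tensor([1, 3, 3, 7])
labels = th.ones((4,)).long()

torch_adam.zero_grad()
loss = th.nn.functional.cross_entropy(torch_emb(idx), labels)
loss.backward()  # sparse gradient covering rows 1, 3 and 7 only
torch_adam.step()

after = torch_emb.weight.detach()
touched = th.zeros(num_nodes, dtype=th.bool)
touched[idx] = True

changed = th.nonzero((after != before).any(dim=1)).flatten().tolist()
print("rows updated:", changed)  # prints: rows updated: [1, 3, 7]
```

In the test itself, the DGL side feeds the same indices and labels through `dgl_emb` and DGL's `SparseAdam`, and the final `F.allclose` compares the resulting rows against `torch_emb`.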

Callers

nothing calls this directly

Calls 15

get_partition_book · Method · 0.95
num_nodes · Method · 0.95
load_partition_book · Function · 0.90
DistGraph · Class · 0.90
DistEmbedding · Class · 0.90
SparseAdam · Class · 0.90
format · Method · 0.80
parameters · Method · 0.80
ctx · Method · 0.45
long · Method · 0.45
to · Method · 0.45
device · Method · 0.45

Tested by

no test coverage detected