Repository: github.com/dmlc/dgl

Method sample_blocks

python/dgl/dataloading/labor_sampler.py, lines 218–255
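Signature: sample_blocks(self, g, seed_nodes, exclude_eids=None)

Samples one block per entry in self.fanouts and returns (input_nodes, output_nodes, blocks), with blocks ordered from the input layer to the output layer.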


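```python
def sample_blocks(self, g, seed_nodes, exclude_eids=None):
    # The seeds passed in are the output nodes of the last block.
    output_nodes = seed_nodes
    blocks = []
    # Walk the fanouts from the output layer back toward the input layer.
    for i, fanout in enumerate(reversed(self.fanouts)):
        # Per-layer random seed; with layer_dependency, every layer reuses the same seed.
        random_seed_i = F.zerocopy_to_dgl_ndarray(
            self.random_seed + (i if not self.layer_dependency else 0)
        )
        # Contribution of the second random seed: self.cnt[0] % self.cnt[1]
        # as a fraction of self.cnt[1], or 0 when self.cnt[1] <= 1.
        if self.cnt[1] <= 1:
            seed2_contr = 0
        else:
            seed2_contr = ((self.cnt[0] % self.cnt[1]) / self.cnt[1]).item()
        # LABOR-sample a frontier around the seeds; importances holds the
        # per-edge importance weights returned by sample_labors.
        frontier, importances = g.sample_labors(
            seed_nodes,
            fanout,
            edge_dir=self.edge_dir,
            prob=self.prob,
            importance_sampling=self.importance_sampling,
            random_seed=random_seed_i,
            seed2_contribution=seed2_contr,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )
        # Convert the frontier into a block whose source nodes include the
        # seeds, and carry over the original edge IDs.
        eid = frontier.edata[EID]
        block = to_block(
            frontier, seed_nodes, include_dst_in_src=True, src_nodes=None
        )
        block.edata[EID] = eid
        # Attach importances as "edge_weights" when their length matches the
        # block's edge count (per edge type for heterographs).
        if len(g.canonical_etypes) > 1:
            for etype, importance in zip(g.canonical_etypes, importances):
                if importance.shape[0] == block.num_edges(etype):
                    block.edata["edge_weights"][etype] = importance
        elif importances[0].shape[0] == block.num_edges():
            block.edata["edge_weights"] = importances[0]
        # This block's source nodes are the seeds for the next (lower) layer;
        # prepending keeps blocks ordered from input layer to output layer.
        seed_nodes = block.srcdata[NID]
        blocks.insert(0, block)

    # Update the sampler's random-seed state (set_seed is defined elsewhere in the class).
    self.set_seed()
    # After the loop, seed_nodes holds the input nodes of the first block.
    return seed_nodes, output_nodes, blocks
```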
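The method builds one block per fanout, starting at the output layer and working toward the input layer. The sketch below reproduces only that control flow in plain Python so it runs without DGL; it is an illustration under stated assumptions, not DGL's implementation. It swaps LABOR for plain uniform neighbor sampling without replacement, uses an adjacency dict in place of a DGLGraph, and represents a block as a hypothetical (sources, destinations, edges) tuple. The helper names `seed2_contribution`, `layer_seed`, and `sample_blocks_sketch` are hypothetical; the two seed helpers only copy the arithmetic of the `seed2_contr` and `random_seed_i` expressions (the uniform sampler does not use a second seed).

```python
import random
from typing import Dict, List, Tuple

# Hypothetical stand-in for a DGL block:
# (source node IDs, destination node IDs, edges as (src, dst) global-ID pairs).
Block = Tuple[List[int], List[int], List[Tuple[int, int]]]


def seed2_contribution(cnt0: int, cnt1: int) -> float:
    # Same arithmetic as seed2_contr: 0 when cnt1 <= 1, otherwise
    # (cnt0 % cnt1) / cnt1.
    if cnt1 <= 1:
        return 0.0
    return (cnt0 % cnt1) / cnt1


def layer_seed(base_seed: int, layer: int, layer_dependency: bool) -> int:
    # Same arithmetic as random_seed_i: one shared seed for all layers when
    # layer_dependency is set, otherwise the seed is offset by the layer index.
    return base_seed + (0 if layer_dependency else layer)


def sample_blocks_sketch(
    adj: Dict[int, List[int]],
    seed_nodes: List[int],
    fanouts: List[int],
    base_seed: int,
    layer_dependency: bool = False,
) -> Tuple[List[int], List[int], List[Block]]:
    output_nodes = list(seed_nodes)
    blocks: List[Block] = []
    # Index 0 of this loop is the output layer, as with reversed(self.fanouts).
    for i, fanout in enumerate(reversed(fanouts)):
        rng = random.Random(layer_seed(base_seed, i, layer_dependency))
        edges: List[Tuple[int, int]] = []
        for v in seed_nodes:
            nbrs = adj.get(v, [])
            # Uniform sampling without replacement (NOT LABOR); fanout < 0 keeps all.
            if fanout < 0 or len(nbrs) <= fanout:
                picked = list(nbrs)
            else:
                picked = rng.sample(nbrs, fanout)
            edges.extend((u, v) for u in picked)
        # Like include_dst_in_src=True: destination nodes come first among sources.
        src = list(seed_nodes)
        seen = set(src)
        for u, _ in edges:
            if u not in seen:
                seen.add(u)
                src.append(u)
        # Prepend so the finished list runs from the input layer to the output layer.
        blocks.insert(0, (src, list(seed_nodes), edges))
        # This block's sources become the next layer's seeds.
        seed_nodes = src
    return seed_nodes, output_nodes, blocks
```

With `adj = {0: [1, 2, 3], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2]}`, seeds `[0]`, and `fanouts=[2, 2]`, the call returns two blocks in which the destination list of `blocks[0]` equals the source list of `blocks[1]`, the same chaining that `seed_nodes = block.srcdata[NID]` and `blocks.insert(0, block)` give the real method.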

Callers

No direct callers detected. In DGL, sample_blocks is normally reached through the base class method BlockSampler.sample.

Calls (3)

set_seed · method · 0.95
to_block · function · 0.90
num_edges · method · 0.45
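The call to `to_block` uses `include_dst_in_src=True`, which, as DGL documents it, places the destination (seed) nodes first among the block's source nodes; that is why `block.srcdata[NID]` can be fed back as the next layer's seeds. The plain-Python sketch below shows that relabeling from global IDs to block-local IDs. The function name `to_block_sketch` is hypothetical, and the first-seen ordering of the extra source nodes is an assumption; DGL may order them differently.

```python
from typing import List, Tuple


def to_block_sketch(
    edges: List[Tuple[int, int]], dst_nodes: List[int]
) -> Tuple[List[int], List[Tuple[int, int]]]:
    # Source list starts with the destination nodes (include_dst_in_src=True).
    src_nodes = list(dst_nodes)
    src_index = {n: i for i, n in enumerate(src_nodes)}
    dst_index = {n: i for i, n in enumerate(dst_nodes)}
    local_edges: List[Tuple[int, int]] = []
    for u, v in edges:
        # Assumption: extra sources are appended in first-seen order.
        if u not in src_index:
            src_index[u] = len(src_nodes)
            src_nodes.append(u)
        # Relabel both endpoints to block-local IDs; every edge's destination
        # must be one of dst_nodes, as with a frontier sampled around the seeds.
        local_edges.append((src_index[u], dst_index[v]))
    return src_nodes, local_edges
```

For edges `[(7, 3), (3, 5), (9, 5)]` and destinations `[3, 5]`, the sources come out as `[3, 5, 7, 9]` and the local edges as `[(2, 0), (0, 1), (3, 1)]`.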

Tested by

no test coverage detected