MCPcopy
hub / github.com/dmlc/dgl / __call__

Method __call__

python/dgl/transforms/module.py:321–375  ·  view source on GitHub ↗
(self, g)

Source from the content-addressed store, hash-verified

319 self.dist = Bernoulli(p)
320
def __call__(self, g):
    """Randomly zero out feature columns of ``g`` in place and return it.

    For every name in ``self.node_feat_names`` (read from ``g.ndata``) and
    ``self.edge_feat_names`` (read from ``g.edata``), a fresh 0/1 mask with
    one entry per feature column is drawn from ``self.dist``. The columns are
    the entries of the tensor's last dimension. Every column whose mask entry
    is 1 is set to zero.

    When a feature is stored as a mapping from node/edge type to tensor
    rather than as a single tensor, an independent mask is drawn for each
    type.

    Parameters
    ----------
    g : DGLGraph
        Graph whose ``ndata`` / ``edata`` feature tensors are modified
        in place.

    Returns
    -------
    DGLGraph
        The same graph object ``g``.

    Raises
    ------
    IndexError
        If a selected feature tensor has fewer than two dimensions, i.e. it
        has no column dimension to mask.
    """
    # Fast path: nothing can be masked when the masking probability is 0.
    if self.p == 0:
        return g

    def _mask_columns(feat):
        # Zero a random subset of ``feat``'s last-dimension columns in place.
        if feat.dim() < 2:
            raise IndexError(
                "feature tensors must have at least 2 dimensions to mask "
                f"columns, got shape {tuple(feat.shape)}"
            )
        feat_mask = self.dist.sample(torch.Size([feat.shape[-1]]))
        # The mask is sized by the last dimension, so it must also index the
        # last dimension. ``[:, mask]`` would index dim 1, which only
        # coincides with the last dim for 2-D tensors.
        feat[..., feat_mask.bool().to(feat.device)] = 0

    def _mask_features(data, feat_names):
        # ``data[name]`` is either a single tensor or a mapping keyed by
        # node/edge type; each typed tensor gets its own mask.
        for name in feat_names:
            feats = data[name]
            if isinstance(feats, torch.Tensor):
                _mask_columns(feats)
            else:
                for key in feats.keys():
                    _mask_columns(feats[key])

    _mask_features(g.ndata, self.node_feat_names)
    _mask_features(g.edata, self.edge_feat_names)
    return g
376
377
378class RandomWalkPE(BaseTransform):

Callers

nothing calls this directly

Calls 3

sample · Method · 0.45
to · Method · 0.45
keys · Method · 0.45

Tested by

no test coverage detected