hub / github.com/dmlc/dgl / _sample_neighbors

Function _sample_neighbors

python/dgl/sampling/neighbor.py:519–683
(
    g,
    nodes,
    fanout,
    edge_dir="in",
    prob=None,
    replace=False,
    copy_ndata=True,
    copy_edata=True,
    _dist_training=False,
    exclude_edges=None,
    fused=False,
    mapping=None,
)

Source from the content-addressed store, hash-verified

def _sample_neighbors(
    g,
    nodes,
    fanout,
    edge_dir="in",
    prob=None,
    replace=False,
    copy_ndata=True,
    copy_edata=True,
    _dist_training=False,
    exclude_edges=None,
    fused=False,
    mapping=None,
):
    if not isinstance(nodes, dict):
        if len(g.ntypes) > 1:
            raise DGLError(
                "Must specify node type when the graph is not homogeneous."
            )
        nodes = {g.ntypes[0]: nodes}

    nodes = utils.prepare_tensor_dict(g, nodes, "nodes")
    if len(nodes) == 0:
        raise ValueError(
            "Got an empty dictionary in the nodes argument. "
            "Please pass in a dictionary with empty tensors as values instead."
        )
    device = utils.context_of(nodes)
    ctx = utils.to_dgl_context(device)
    nodes_all_types = []
    for ntype in g.ntypes:
        if ntype in nodes:
            nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))
        else:
            nodes_all_types.append(nd.array([], ctx=ctx))

    if isinstance(fanout, nd.NDArray):
        fanout_array = fanout
    else:
        if not isinstance(fanout, dict):
            fanout_array = [int(fanout)] * len(g.etypes)
        else:
            if len(fanout) != len(g.etypes):
                raise DGLError(
                    "Fan-out must be specified for each edge type "
                    "if a dict is provided."
                )
            fanout_array = [None] * len(g.etypes)
            for etype, value in fanout.items():
                fanout_array[g.get_etype_id(etype)] = value
        fanout_array = F.to_dgl_nd(F.tensor(fanout_array, dtype=F.int64))

    prob_arrays = _prepare_edge_arrays(g, prob)

    excluded_edges_all_t = []
    if exclude_edges is not None:
        if not isinstance(exclude_edges, dict):
            if len(g.etypes) > 1:
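
The argument normalization in the excerpt above can be sketched without DGL. This is a minimal illustration under stated assumptions, not the library's implementation: `normalize_nodes` and `normalize_fanout` are hypothetical helper names, plain Python lists stand in for tensors and `nd.NDArray`, a position in the `etypes` list stands in for `g.get_etype_id`, and `ValueError` stands in for `DGLError`.

```python
def normalize_nodes(nodes, ntypes):
    """Return one seed-node list per node type, in the order of `ntypes`.

    Hypothetical stand-in for the node handling in the excerpt: a bare
    sequence is accepted only when there is a single node type, and node
    types missing from the dict get an empty list.
    """
    if not isinstance(nodes, dict):
        if len(ntypes) > 1:
            raise ValueError(
                "Must specify node type when the graph is not homogeneous."
            )
        nodes = {ntypes[0]: nodes}
    if len(nodes) == 0:
        raise ValueError("Got an empty dictionary in the nodes argument.")
    return [list(nodes.get(ntype, [])) for ntype in ntypes]


def normalize_fanout(fanout, etypes):
    """Return one fan-out value per edge type, in the order of `etypes`.

    Hypothetical stand-in for the fan-out handling in the excerpt: a scalar
    is broadcast to every edge type, while a dict must cover every edge type.
    """
    if not isinstance(fanout, dict):
        return [int(fanout)] * len(etypes)
    if len(fanout) != len(etypes):
        raise ValueError(
            "Fan-out must be specified for each edge type "
            "if a dict is provided."
        )
    # Assumption: the edge type ID is its position in `etypes`.
    etype_id = {etype: i for i, etype in enumerate(etypes)}
    fanout_list = [None] * len(etypes)
    for etype, value in fanout.items():
        fanout_list[etype_id[etype]] = value
    return fanout_list


if __name__ == "__main__":
    print(normalize_nodes({"item": [2, 5]}, ["user", "item"]))
    print(normalize_fanout(3, ["follows", "buys"]))
    print(normalize_fanout({"buys": 1, "follows": 2}, ["follows", "buys"]))
```

In this sketch both results are ordered by type ID rather than by dict insertion order. The real function goes further than the sketch: it converts these lists to DGL arrays on the device of the input nodes.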

Callers 2

sample_neighbors · Function · 0.70
sample_neighbors_fused · Function · 0.70

Calls 11

DGLError · Class · 0.85
_prepare_edge_arrays · Function · 0.85
DGLBlock · Class · 0.85
DGLGraph · Class · 0.85
append · Method · 0.80
tousertensor · Method · 0.80
items · Method · 0.45
get_etype_id · Method · 0.45
keys · Method · 0.45
num_nodes · Method · 0.45
get · Method · 0.45

Tested by

no test coverage detected