(
g,
nodes,
fanout,
edge_dir="in",
prob=None,
replace=False,
copy_ndata=True,
copy_edata=True,
_dist_training=False,
exclude_edges=None,
fused=False,
mapping=None,
)
| 517 | |
| 518 | |
| 519 | def _sample_neighbors( |
| 520 | g, |
| 521 | nodes, |
| 522 | fanout, |
| 523 | edge_dir="in", |
| 524 | prob=None, |
| 525 | replace=False, |
| 526 | copy_ndata=True, |
| 527 | copy_edata=True, |
| 528 | _dist_training=False, |
| 529 | exclude_edges=None, |
| 530 | fused=False, |
| 531 | mapping=None, |
| 532 | ): |
| 533 | if not isinstance(nodes, dict): |
| 534 | if len(g.ntypes) > 1: |
| 535 | raise DGLError( |
| 536 | "Must specify node type when the graph is not homogeneous." |
| 537 | ) |
| 538 | nodes = {g.ntypes[0]: nodes} |
| 539 | |
| 540 | nodes = utils.prepare_tensor_dict(g, nodes, "nodes") |
| 541 | if len(nodes) == 0: |
| 542 | raise ValueError( |
| 543 | "Got an empty dictionary in the nodes argument. " |
| 544 | "Please pass in a dictionary with empty tensors as values instead." |
| 545 | ) |
| 546 | device = utils.context_of(nodes) |
| 547 | ctx = utils.to_dgl_context(device) |
| 548 | nodes_all_types = [] |
| 549 | for ntype in g.ntypes: |
| 550 | if ntype in nodes: |
| 551 | nodes_all_types.append(F.to_dgl_nd(nodes[ntype])) |
| 552 | else: |
| 553 | nodes_all_types.append(nd.array([], ctx=ctx)) |
| 554 | |
| 555 | if isinstance(fanout, nd.NDArray): |
| 556 | fanout_array = fanout |
| 557 | else: |
| 558 | if not isinstance(fanout, dict): |
| 559 | fanout_array = [int(fanout)] * len(g.etypes) |
| 560 | else: |
| 561 | if len(fanout) != len(g.etypes): |
| 562 | raise DGLError( |
| 563 | "Fan-out must be specified for each edge type " |
| 564 | "if a dict is provided." |
| 565 | ) |
| 566 | fanout_array = [None] * len(g.etypes) |
| 567 | for etype, value in fanout.items(): |
| 568 | fanout_array[g.get_etype_id(etype)] = value |
| 569 | fanout_array = F.to_dgl_nd(F.tensor(fanout_array, dtype=F.int64)) |
| 570 | |
| 571 | prob_arrays = _prepare_edge_arrays(g, prob) |
| 572 | |
| 573 | excluded_edges_all_t = [] |
| 574 | if exclude_edges is not None: |
| 575 | if not isinstance(exclude_edges, dict): |
| 576 | if len(g.etypes) > 1: |
no test coverage detected