(
g,
nodes,
fanout,
edge_dir="in",
prob=None,
importance_sampling=0,
random_seed=None,
seed2_contribution=0,
copy_ndata=True,
copy_edata=True,
exclude_edges=None,
)
| 242 | |
| 243 | |
| 244 | def _sample_labors( |
| 245 | g, |
| 246 | nodes, |
| 247 | fanout, |
| 248 | edge_dir="in", |
| 249 | prob=None, |
| 250 | importance_sampling=0, |
| 251 | random_seed=None, |
| 252 | seed2_contribution=0, |
| 253 | copy_ndata=True, |
| 254 | copy_edata=True, |
| 255 | exclude_edges=None, |
| 256 | ): |
| 257 | if random_seed is None: |
| 258 | random_seed = F.to_dgl_nd(choice(1e18, 1)) |
| 259 | if not isinstance(nodes, dict): |
| 260 | if len(g.ntypes) > 1: |
| 261 | raise DGLError( |
| 262 | "Must specify node type when the graph is not homogeneous." |
| 263 | ) |
| 264 | nodes = {g.ntypes[0]: nodes} |
| 265 | |
| 266 | nodes = utils.prepare_tensor_dict(g, nodes, "nodes") |
| 267 | if len(nodes) == 0: |
| 268 | raise ValueError( |
| 269 | "Got an empty dictionary in the nodes argument. " |
| 270 | "Please pass in a dictionary with empty tensors as values instead." |
| 271 | ) |
| 272 | ctx = utils.to_dgl_context(F.context(next(iter(nodes.values())))) |
| 273 | nodes_all_types = [] |
| 274 | # nids_all_types is needed if one wants labor to work for subgraphs whose vertices have |
| 275 | # been renamed and the rolled randoms should be rolled for global vertex ids. |
| 276 | # It is disabled for now below by passing empty ndarrays. |
| 277 | nids_all_types = [nd.array([], ctx=ctx) for _ in g.ntypes] |
| 278 | for ntype in g.ntypes: |
| 279 | if ntype in nodes: |
| 280 | nodes_all_types.append(F.to_dgl_nd(nodes[ntype])) |
| 281 | else: |
| 282 | nodes_all_types.append(nd.array([], ctx=ctx)) |
| 283 | |
| 284 | if isinstance(fanout, nd.NDArray): |
| 285 | fanout_array = fanout |
| 286 | else: |
| 287 | if not isinstance(fanout, dict): |
| 288 | fanout_array = [int(fanout)] * len(g.etypes) |
| 289 | else: |
| 290 | if len(fanout) != len(g.etypes): |
| 291 | raise DGLError( |
| 292 | "Fan-out must be specified for each edge type " |
| 293 | "if a dict is provided." |
| 294 | ) |
| 295 | fanout_array = [None] * len(g.etypes) |
| 296 | for etype, value in fanout.items(): |
| 297 | fanout_array[g.get_etype_id(etype)] = value |
| 298 | fanout_array = F.to_dgl_nd(F.tensor(fanout_array, dtype=F.int64)) |
| 299 | |
| 300 | if ( |
| 301 | isinstance(prob, list) |
no test coverage detected