MCPcopy
hub / github.com/dmlc/dgl / _sample_labors

Function _sample_labors

python/dgl/sampling/labor.py:244–356  ·  view source on GitHub ↗
(
    g,
    nodes,
    fanout,
    edge_dir="in",
    prob=None,
    importance_sampling=0,
    random_seed=None,
    seed2_contribution=0,
    copy_ndata=True,
    copy_edata=True,
    exclude_edges=None,
)

Source from the content-addressed store, hash-verified

242
243
244def _sample_labors(
245 g,
246 nodes,
247 fanout,
248 edge_dir="in",
249 prob=None,
250 importance_sampling=0,
251 random_seed=None,
252 seed2_contribution=0,
253 copy_ndata=True,
254 copy_edata=True,
255 exclude_edges=None,
256):
257 if random_seed is None:
258 random_seed = F.to_dgl_nd(choice(1e18, 1))
259 if not isinstance(nodes, dict):
260 if len(g.ntypes) > 1:
261 raise DGLError(
262 "Must specify node type when the graph is not homogeneous."
263 )
264 nodes = {g.ntypes[0]: nodes}
265
266 nodes = utils.prepare_tensor_dict(g, nodes, "nodes")
267 if len(nodes) == 0:
268 raise ValueError(
269 "Got an empty dictionary in the nodes argument. "
270 "Please pass in a dictionary with empty tensors as values instead."
271 )
272 ctx = utils.to_dgl_context(F.context(next(iter(nodes.values()))))
273 nodes_all_types = []
274 # nids_all_types is needed if one wants labor to work for subgraphs whose vertices have
275 # been renamed and the rolled randoms should be rolled for global vertex ids.
276 # It is disabled for now below by passing empty ndarrays.
277 nids_all_types = [nd.array([], ctx=ctx) for _ in g.ntypes]
278 for ntype in g.ntypes:
279 if ntype in nodes:
280 nodes_all_types.append(F.to_dgl_nd(nodes[ntype]))
281 else:
282 nodes_all_types.append(nd.array([], ctx=ctx))
283
284 if isinstance(fanout, nd.NDArray):
285 fanout_array = fanout
286 else:
287 if not isinstance(fanout, dict):
288 fanout_array = [int(fanout)] * len(g.etypes)
289 else:
290 if len(fanout) != len(g.etypes):
291 raise DGLError(
292 "Fan-out must be specified for each edge type "
293 "if a dict is provided."
294 )
295 fanout_array = [None] * len(g.etypes)
296 for etype, value in fanout.items():
297 fanout_array[g.get_etype_id(etype)] = value
298 fanout_array = F.to_dgl_nd(F.tensor(fanout_array, dtype=F.int64))
299
300 if (
301 isinstance(prob, list)

Callers 1

sample_labors · Function · 0.85

Calls 9

choice · Function · 0.85
DGLError · Class · 0.85
DGLGraph · Class · 0.85
context · Method · 0.80
append · Method · 0.80
values · Method · 0.45
items · Method · 0.45
get_etype_id · Method · 0.45
cpu · Method · 0.45

Tested by

no test coverage detected