Function prepare_tensor

python/dgl/utils/checks.py:12–70 (github.com/dmlc/dgl)

Convert the data to an ID tensor and check its ID type and context. If the data is already a tensor, raise an error if its ID type or context does not match the graph's (the context check is skipped when the graph is pinned). Otherwise, convert it to a tensor of the graph's ID type and context and return it.
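The tensor branch only validates its input and never converts it. The stdlib-only sketch below mirrors its two checks. `FakeGraph`, `FakeTensor`, `check_tensor_ids` and the local `DGLError` are hypothetical stand-ins (not DGL APIs), and dtypes and devices are modeled as plain strings rather than backend objects.

```python
from dataclasses import dataclass, field
from typing import List


class DGLError(Exception):
    """Stand-in for DGL's DGLError (hypothetical, for this sketch only)."""


@dataclass
class FakeGraph:
    # Hypothetical graph carrying only what the checks need.
    idtype: str = "int64"
    device: str = "cpu"
    pinned: bool = False

    def is_pinned(self) -> bool:
        return self.pinned


@dataclass
class FakeTensor:
    # Hypothetical tensor: a list of IDs tagged with a dtype and a device.
    values: List[int] = field(default_factory=list)
    dtype: str = "int64"
    device: str = "cpu"


def check_tensor_ids(g: FakeGraph, data: FakeTensor, name: str) -> FakeTensor:
    """Mirror of the tensor branch: validate the input, never convert it."""
    if data.dtype != g.idtype:
        raise DGLError(
            f'Expect argument "{name}" to have data type {g.idtype}. '
            f"But got {data.dtype}."
        )
    # The device check is skipped entirely when the graph is pinned.
    if data.device != g.device and not g.is_pinned():
        raise DGLError(
            f'Expect argument "{name}" to have device {g.device}. '
            f"But got {data.device}."
        )
    return data


g = FakeGraph(idtype="int64", device="cuda:0")
ids = FakeTensor([0, 1, 2], dtype="int64", device="cpu")
try:
    check_tensor_ids(g, ids, "nodes")
except DGLError as err:
    print(err)  # Expect argument "nodes" to have device cuda:0. But got cpu.

g.pinned = True
print(check_tensor_ids(g, ids, "nodes") is ids)  # True
```

A dtype mismatch is always an error, while a device mismatch is tolerated for a pinned graph; the sketch does not model what the pinned graph then does with IDs on another device.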

(g, data, name)

```python
# Module-level imports used by this function (python/dgl/utils/checks.py).
from .. import backend as F
from ..base import DGLError


def prepare_tensor(g, data, name):
    """Convert the data to an ID tensor and check its ID type and context.

    If the data is already a tensor, raise an error if its ID type or
    context does not match the graph's. Otherwise, convert it to a tensor
    of the graph's ID type and context and return it.

    Parameters
    ----------
    g : DGLGraph
        The graph whose ID type and device the data is checked against.
    data : int, iterable of int, tensor
        The IDs to check or convert.
    name : str
        Name of the argument, used in error messages.

    Returns
    -------
    Tensor
        The IDs as a 1-D tensor.
    """
    if F.is_tensor(data):
        # Existing tensors are validated, never converted.
        if F.dtype(data) != g.idtype:
            raise DGLError(
                f'Expect argument "{name}" to have data type {g.idtype}. '
                f"But got {F.dtype(data)}."
            )
        # The device check is skipped when the graph is pinned.
        if F.context(data) != g.device and not g.is_pinned():
            raise DGLError(
                f'Expect argument "{name}" to have device {g.device}. '
                f"But got {F.context(data)}."
            )
        ret = data
    else:
        data = F.tensor(data)
        # An empty tensor is accepted regardless of its dtype.
        is_empty = F.ndim(data) > 0 and F.shape(data)[0] == 0
        if not is_empty and F.dtype(data) not in (F.int32, F.int64):
            raise DGLError(
                'Expect argument "{}" to have data type int32 or int64,'
                " but got {}.".format(name, F.dtype(data))
            )
        # Cast to the graph's ID type and copy to the graph's device.
        ret = F.copy_to(F.astype(data, g.idtype), g.device)

    # Promote a scalar ID to a length-1 vector; reject anything above 1-D.
    if F.ndim(ret) == 0:
        ret = F.unsqueeze(ret, 0)
    if F.ndim(ret) > 1:
        raise DGLError(
            'Expect a 1-D tensor for argument "{}". But got {}.'.format(
                name, ret
            )
        )
    return ret
```
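For non-tensor input, the function builds a tensor, accepts it only if it is empty or integer-typed, casts and copies it, and finally forces it to be 1-D. The stdlib-only analogue below sketches those rules. `to_id_list` and the local `DGLError` are hypothetical, a flat Python list stands in for the 1-D tensor, and the cast to `g.idtype` and the device copy are not modeled.

```python
from numbers import Integral
from typing import Any, Iterator, List

_SEQUENCES = (list, tuple, range)


class DGLError(Exception):
    """Stand-in for DGL's DGLError (hypothetical, for this sketch only)."""


def _leaves(x: Any) -> Iterator[Any]:
    # Yield the scalar elements of an arbitrarily nested sequence.
    if isinstance(x, _SEQUENCES):
        for item in x:
            yield from _leaves(item)
    else:
        yield x


def to_id_list(data: Any, name: str) -> List[int]:
    """Hypothetical stdlib analogue of the non-tensor branch."""
    if not isinstance(data, _SEQUENCES):
        # A scalar input is promoted to a length-1 vector (cf. F.unsqueeze).
        data = [data]
    items = list(data)
    if not items:
        # An empty input is accepted whatever its element type.
        return []
    # bool is a subclass of int in Python, but it is not an ID type.
    if not all(isinstance(x, Integral) and not isinstance(x, bool)
               for x in _leaves(items)):
        raise DGLError(
            f'Expect argument "{name}" to have data type int32 or int64.'
        )
    if any(isinstance(x, _SEQUENCES) for x in items):
        # Nested input would yield a tensor with ndim > 1.
        raise DGLError(f'Expect a 1-D tensor for argument "{name}".')
    return [int(x) for x in items]


print(to_id_list(3, "nodes"))          # [3]
print(to_id_list((0, 2, 5), "nodes"))  # [0, 2, 5]
print(to_id_list([], "nodes"))         # []
```

Unlike the real function, the sketch cannot distinguish int32 from int64, because Python integers are unbounded.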

Callers 3

prepare_tensor_dict · Function · 0.85
prepare_tensor_or_dict · Function · 0.85
parse_edges_arg_to_eid · Function · 0.85
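The callers' bodies are not shown on this page. Assuming `prepare_tensor_dict` applies `prepare_tensor` to every value of a per-type mapping and `prepare_tensor_or_dict` dispatches on whether the input is a mapping, the pattern looks roughly like the sketch below. All names are hypothetical, `prepare_ids` is a deliberately simplified validator standing in for `prepare_tensor`, and the `name["key"]` error label is an assumption.

```python
from collections.abc import Mapping
from typing import Any, Dict, List, Union

IdList = List[int]


def prepare_ids(data: Any, name: str) -> IdList:
    """Hypothetical, simplified stand-in for prepare_tensor."""
    ids = [data] if isinstance(data, int) else list(data)
    if not all(isinstance(x, int) and not isinstance(x, bool) for x in ids):
        raise TypeError(f'Expect argument "{name}" to hold integer IDs.')
    return ids


def prepare_ids_dict(data: Mapping, name: str) -> Dict[Any, IdList]:
    # Validate each per-type entry separately; folding the key into the
    # name makes an error message point at the offending type.
    return {key: prepare_ids(val, f'{name}["{key}"]') for key, val in data.items()}


def prepare_ids_or_dict(data: Any, name: str) -> Union[IdList, Dict[Any, IdList]]:
    # Per-type IDs arrive as a mapping; homogeneous IDs as one sequence.
    if isinstance(data, Mapping):
        return prepare_ids_dict(data, name)
    return prepare_ids(data, name)


print(prepare_ids_or_dict({"user": [0, 1], "item": 4}, "nodes"))
# {'user': [0, 1], 'item': [4]}
```

Checking against `collections.abc.Mapping` rather than `dict` lets any mapping type take the per-type path.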

Calls 8

DGLError · Class · 0.85
context · Method · 0.80
format · Method · 0.80
dtype · Method · 0.45
is_pinned · Method · 0.45
shape · Method · 0.45
copy_to · Method · 0.45
astype · Method · 0.45

Tested by

no test coverage detected