hub / github.com/dmlc/dgl / infer_broadcast_shape

Function infer_broadcast_shape

python/dgl/_sparse_ops.py:10–60

Check the shape validity and infer the output shape given the input shapes and the operator. Note that `shp1`, `shp2`, and the returned shape are all feature shapes (i.e. the first dimension, which corresponds to graph statistics such as the number of nodes, number of edges, etc., is removed).

(op, shp1, shp2)

Source from the content-addressed store, hash-verified

def infer_broadcast_shape(op, shp1, shp2):
    r"""Check the shape validity, and infer the output shape given input shape and operator.
    Note the both :attr:`shp1`, :attr:`shp2` and the returned shape are feature
    shapes (i.e. we remove the first dimension, which correspond to graph statistics
    such as number of nodes, number of edges, etc.).

    We allow applying op on operands with different shapes, according to the
    broadcasting semantics of Numpy/Scipy:
    https://numpy.org/doc/stable/user/basics.broadcasting.html

    Parameters
    ----------
    op : str
        The binary op's name, could be `add`, `sub`, `mul`, `div`, `dot`, `copy_lhs`, `copy_rhs`.
    shp1 : tuple[int]
        The shape of lhs operand.
    shp2 : tuple[int]
        The shape of rhs operand.

    Returns
    -------
    tuple[int]
        shape after broadcasting
    """
    pad_shp1, pad_shp2 = shp1, shp2
    if op == "dot":
        if shp1[-1] != shp2[-1]:
            raise DGLError(
                "Dot operator is only available for arrays with the "
                "same size on last dimension, but got {} and {}.".format(
                    shp1, shp2
                )
            )
    if op == "copy_lhs":
        return shp1
    if op == "copy_rhs":
        return shp2
    # operands are padded to have the same dimensionality with leading 1's.
    if len(shp1) > len(shp2):
        pad_shp2 = (1,) * (len(shp1) - len(shp2)) + shp2
    elif len(shp1) < len(shp2):
        pad_shp1 = (1,) * (len(shp2) - len(shp1)) + shp1
    for d1, d2 in zip(pad_shp1, pad_shp2):
        if d1 != d2 and d1 != 1 and d2 != 1:
            raise DGLError(
                "Feature shapes {} and {} are not valid for broadcasting.".format(
                    shp1, shp2
                )
            )
    rst = tuple(max(d1, d2) for d1, d2 in zip(pad_shp1, pad_shp2))
    return rst[:-1] + (1,) if op == "dot" else rst
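
To make the shape rule easier to follow, here is a minimal standalone sketch that mirrors the checks in the listing above without importing DGL. The names `broadcast_feature_shape` and `BroadcastError` are hypothetical stand-ins introduced for illustration (the real function raises `DGLError`), and the error messages are shortened.

```python
# Standalone sketch of the shape rule, with no DGL dependency.
# `BroadcastError` is a hypothetical stand-in for DGLError.
class BroadcastError(ValueError):
    pass


def broadcast_feature_shape(op, shp1, shp2):
    """Mirror the padding and compatibility checks of infer_broadcast_shape."""
    if op == "dot" and shp1[-1] != shp2[-1]:
        raise BroadcastError(f"dot needs equal last dims, got {shp1} and {shp2}")
    if op == "copy_lhs":
        return shp1
    if op == "copy_rhs":
        return shp2
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    rank = max(len(shp1), len(shp2))
    pad1 = (1,) * (rank - len(shp1)) + tuple(shp1)
    pad2 = (1,) * (rank - len(shp2)) + tuple(shp2)
    # Each pair of dimensions must match, or one of them must be 1.
    for d1, d2 in zip(pad1, pad2):
        if d1 != d2 and d1 != 1 and d2 != 1:
            raise BroadcastError(f"{shp1} and {shp2} do not broadcast")
    out = tuple(max(d1, d2) for d1, d2 in zip(pad1, pad2))
    # For "dot", the last axis of the result is replaced by size 1.
    return out[:-1] + (1,) if op == "dot" else out


print(broadcast_feature_shape("add", (3, 1), (4,)))       # (3, 4)
print(broadcast_feature_shape("mul", (5,), (2, 1, 5)))    # (2, 1, 5)
print(broadcast_feature_shape("dot", (2, 8), (8,)))       # (2, 1)
print(broadcast_feature_shape("copy_lhs", (7,), (3, 2)))  # (7,)
```

Note that `copy_lhs` and `copy_rhs` return before any compatibility check, so the other operand's shape is ignored for those ops, while the `dot` last-dimension check runs first.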

Callers 4

_gspmm · Function · 0.85
_gspmm_hetero · Function · 0.85
_gsddmm · Function · 0.85
_gsddmm_hetero · Function · 0.85
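
Because the function works on feature shapes only, a caller has to drop the leading graph dimension before the call and put an output row count back afterwards. The sketch below illustrates that pattern under stated assumptions: it is not copied from `_gspmm` or `_gsddmm`, the names `feature_broadcast`, `output_shape`, and `num_out` are hypothetical, and the helper is a stdlib-only stand-in for the elementwise-op case.

```python
from itertools import zip_longest


def feature_broadcast(shp1, shp2):
    """Hypothetical stand-in for the elementwise case of infer_broadcast_shape."""
    out = []
    # Align shapes from the trailing dimension; a missing dimension counts as 1.
    for d1, d2 in zip_longest(reversed(shp1), reversed(shp2), fillvalue=1):
        if d1 != d2 and 1 not in (d1, d2):
            raise ValueError(f"{shp1} and {shp2} do not broadcast")
        out.append(max(d1, d2))
    return tuple(reversed(out))


def output_shape(num_out, lhs_shape, rhs_shape):
    """Strip the leading graph dimension, broadcast, then prepend num_out."""
    return (num_out,) + feature_broadcast(lhs_shape[1:], rhs_shape[1:])


# Assumed example: 10 source nodes with (4, 1) features, 30 edges with (5,)
# features, and 12 output rows (e.g. destination nodes).
print(output_shape(12, (10, 4, 1), (30, 5)))  # (12, 4, 5)
```

The leading dimensions (10 and 30 here) never take part in the broadcast, which is why they may differ freely between the two operands.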

Calls 3

DGLError · Class · 0.85
format · Method · 0.80
max · Function · 0.50

Tested by

no test coverage detected