github.com/dmlc/dgl / svd_pe

Function svd_pe

python/dgl/transforms/functional.py:4012–4086

SVD-based Positional Encoding, as introduced in "Global Self-Attention as a Replacement for Graph Convolution" (https://arxiv.org/pdf/2108.03348.pdf). This function computes the largest k singular values and the corresponding left and right singular vectors to form positional encodings.
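Written out, the computation in the source below amounts to the following. The notation is introduced here for illustration only: A is the dense N × N adjacency matrix returned by `adj_external`, the σ_i are its singular values in descending order (as returned by `scipy.linalg.svd`), and m = min(k, N).

```latex
\begin{aligned}
A &= U \Sigma V^{\top}, \qquad \Sigma = \operatorname{diag}(\sigma_1, \dots, \sigma_N), \quad \sigma_1 \ge \dots \ge \sigma_N \ge 0,\\
m &= \min(k, N), \qquad \Sigma_m^{1/2} = \operatorname{diag}\big(\sqrt{\sigma_1}, \dots, \sqrt{\sigma_m}\big),\\
\mathrm{PE} &= \big[\, U_{:,1:m}\, \Sigma_m^{1/2} \;\;\; V_{:,1:m}\, \Sigma_m^{1/2} \,\big] \in \mathbb{R}^{N \times 2m}.
\end{aligned}
```

Because (U_{:,1:m} Σ_m^{1/2})(V_{:,1:m} Σ_m^{1/2})^⊤ = U_{:,1:m} Σ_m V_{:,1:m}^⊤, the two halves of each node's encoding together carry the best rank-m approximation of A. When padding=True and k > N, the docstring states that zeros are appended so the result has 2k columns.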

Signature: svd_pe(g, k, padding=False, random_flip=True)


def svd_pe(g, k, padding=False, random_flip=True):
    r"""SVD-based Positional Encoding, as introduced in
    `Global Self-Attention as a Replacement for Graph Convolution
    <https://arxiv.org/pdf/2108.03348.pdf>`__

    This function computes the largest :math:`k` singular values and
    corresponding left and right singular vectors to form positional encodings.

    Parameters
    ----------
    g : DGLGraph
        A DGLGraph to be encoded, which must be a homogeneous one.
    k : int
        Number of largest singular values and corresponding singular vectors
        used for positional encoding.
    padding : bool, optional
        If False, raise an error when :math:`k > N`,
        where :math:`N` is the number of nodes in :attr:`g`.
        If True, add zero paddings in the end of encoding vectors when
        :math:`k > N`.
        Default : False.
    random_flip : bool, optional
        If True, randomly flip the signs of encoding vectors.
        Proposed to be activated during training for better generalization.
        Default : True.

    Returns
    -------
    Tensor
        Return SVD-based positional encodings of shape :math:`(N, 2k)`.

    Example
    -------
    >>> import dgl

    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
    >>> dgl.svd_pe(g, k=2, padding=False, random_flip=True)
    tensor([[-6.3246e-01, -1.1373e-07, -6.3246e-01,  0.0000e+00],
            [-6.3246e-01,  7.6512e-01, -6.3246e-01, -7.6512e-01],
            [ 6.3246e-01,  4.7287e-01,  6.3246e-01, -4.7287e-01],
            [-6.3246e-01, -7.6512e-01, -6.3246e-01,  7.6512e-01],
            [ 6.3246e-01, -4.7287e-01,  6.3246e-01,  4.7287e-01]])
    """
    n = g.num_nodes()
    if not padding and n < k:
        raise ValueError(
            "The number of singular values k must be no greater than the "
            "number of nodes n, but " + f"got {k} and {n} respectively."
        )
    a = g.adj_external(ctx=g.device, scipy_fmt="coo").toarray()
    u, d, vh = scipy.linalg.svd(a)
    v = vh.transpose()
    m = min(n, k)
    topm_u = u[:, 0:m]
    topm_v = v[:, 0:m]
    topm_sqrt_d = sparse.diags(np.sqrt(d[0:m]))
    encoding = np.concatenate(
        ((topm_u @ topm_sqrt_d), (topm_v @ topm_sqrt_d)), axis=1
    )
    # ... (the rest of the function, through line 4086, is not shown in this excerpt)
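The excerpt stops before the sign flip, the padding, and the return. The following minimal NumPy sketch of the whole computation can be handy for checking shapes without DGL installed. It rests on several assumptions. `svd_pe_sketch` is a hypothetical helper that takes a dense adjacency matrix rather than a DGLGraph. It uses `np.linalg.svd` in place of `scipy.linalg.svd`, and column-wise broadcasting in place of multiplying by `sparse.diags`. The sign flip and the trailing zero padding are inferred from the docstring, not copied from DGL. The flip uses one random ±1 per node, applied to that node's whole row. That choice is suggested by the mixed signs in the first column of the docstring example, because the leading singular vector for that graph is constant.

```python
import numpy as np


def svd_pe_sketch(adj, k, padding=False, random_flip=True, rng=None):
    """Hypothetical NumPy sketch of SVD-based positional encodings.

    adj: dense (N, N) adjacency matrix (not a DGLGraph).
    Returns an (N, 2k) float32 array.
    """
    n = adj.shape[0]
    if not padding and n < k:
        raise ValueError(
            f"k must be no greater than the number of nodes n, got {k} and {n}"
        )
    # Full SVD: adj = u @ diag(d) @ vh, with d sorted in descending order.
    u, d, vh = np.linalg.svd(adj)
    v = vh.T
    m = min(n, k)
    sqrt_d = np.sqrt(d[:m])
    # Scale the top-m left/right singular vectors by sqrt(sigma_i);
    # broadcasting over columns equals right-multiplying by diag(sqrt_d).
    encoding = np.concatenate([u[:, :m] * sqrt_d, v[:, :m] * sqrt_d], axis=1)
    if random_flip:
        # Assumption: one random sign per node, applied to that node's whole row.
        rng = np.random.default_rng() if rng is None else rng
        signs = rng.choice([-1.0, 1.0], size=n)
        encoding = signs[:, None] * encoding
    if n < k:
        # Assumption: zero columns appended at the end to reach 2k columns.
        encoding = np.concatenate([encoding, np.zeros((n, 2 * (k - n)))], axis=1)
    return encoding.astype(np.float32)


# Directed graph from the docstring example: adj[src, dst] = 1 for each edge.
src = [0, 1, 2, 3, 4, 2, 3, 1, 4, 0]
dst = [2, 3, 1, 4, 0, 0, 1, 2, 3, 4]
adj = np.zeros((5, 5))
adj[src, dst] = 1.0

pe = svd_pe_sketch(adj, k=2, rng=np.random.default_rng(0))
print(pe.shape)  # (5, 4)
```

In the docstring graph every node has in-degree and out-degree 2. The top singular value is therefore σ₁ = 2, with leading singular vectors ±(1/√5)·1. The first column of the encoding is then ±√(2/5) ≈ ±0.63246, which matches the magnitudes in the docstring example. With random_flip=False and k ≥ N, the two halves also reconstruct A exactly, because (UΣ^{1/2})(VΣ^{1/2})^⊤ = UΣV^⊤.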

Callers

nothing calls this directly

Calls (5)

adj_external · method · 0.80
transpose · method · 0.80
context · method · 0.80
min · function · 0.50
num_nodes · method · 0.45

Tested by

no test coverage detected