r"""SVD-based Positional Encoding, as introduced in `Global Self-Attention as a Replacement for Graph Convolution <https://arxiv.org/pdf/2108.03348.pdf>`__. This function computes the largest :math:`k` singular values and the corresponding left and right singular vectors to form positional encodings.
(g, k, padding=False, random_flip=True)
| 4010 | |
| 4011 | |
| 4012 | def svd_pe(g, k, padding=False, random_flip=True): |
| 4013 | r"""SVD-based Positional Encoding, as introduced in |
| 4014 | `Global Self-Attention as a Replacement for Graph Convolution |
| 4015 | <https://arxiv.org/pdf/2108.03348.pdf>`__ |
| 4016 | |
| 4017 | This function computes the largest :math:`k` singular values and |
| 4018 | corresponding left and right singular vectors to form positional encodings. |
| 4019 | |
| 4020 | Parameters |
| 4021 | ---------- |
| 4022 | g : DGLGraph |
| 4023 | A DGLGraph to be encoded, which must be a homogeneous one. |
| 4024 | k : int |
| 4025 | Number of largest singular values and corresponding singular vectors |
| 4026 | used for positional encoding. |
| 4027 | padding : bool, optional |
| 4028 | If False, raise an error when :math:`k > N`, |
| 4029 | where :math:`N` is the number of nodes in :attr:`g`. |
| 4030 | If True, add zero paddings in the end of encoding vectors when |
| 4031 | :math:`k > N`. |
| 4032 | Default : False. |
| 4033 | random_flip : bool, optional |
| 4034 | If True, randomly flip the signs of encoding vectors. |
| 4035 | Proposed to be activated during training for better generalization. |
| 4036 | Default : True. |
| 4037 | |
| 4038 | Returns |
| 4039 | ------- |
| 4040 | Tensor |
| 4041 | Return SVD-based positional encodings of shape :math:`(N, 2k)`. |
| 4042 | |
| 4043 | Example |
| 4044 | ------- |
| 4045 | >>> import dgl |
| 4046 | |
| 4047 | >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4])) |
| 4048 | >>> dgl.svd_pe(g, k=2, padding=False, random_flip=True) |
| 4049 | tensor([[-6.3246e-01, -1.1373e-07, -6.3246e-01, 0.0000e+00], |
| 4050 | [-6.3246e-01, 7.6512e-01, -6.3246e-01, -7.6512e-01], |
| 4051 | [ 6.3246e-01, 4.7287e-01, 6.3246e-01, -4.7287e-01], |
| 4052 | [-6.3246e-01, -7.6512e-01, -6.3246e-01, 7.6512e-01], |
| 4053 | [ 6.3246e-01, -4.7287e-01, 6.3246e-01, 4.7287e-01]]) |
| 4054 | """ |
| 4055 | n = g.num_nodes() |
| 4056 | if not padding and n < k: |
| 4057 | raise ValueError( |
| 4058 | "The number of singular values k must be no greater than the " |
| 4059 | "number of nodes n, but " + f"got {k} and {n} respectively." |
| 4060 | ) |
| 4061 | a = g.adj_external(ctx=g.device, scipy_fmt="coo").toarray() |
| 4062 | u, d, vh = scipy.linalg.svd(a) |
| 4063 | v = vh.transpose() |
| 4064 | m = min(n, k) |
| 4065 | topm_u = u[:, 0:m] |
| 4066 | topm_v = v[:, 0:m] |
| 4067 | topm_sqrt_d = sparse.diags(np.sqrt(d[0:m])) |
| 4068 | encoding = np.concatenate( |
| 4069 | ((topm_u @ topm_sqrt_d), (topm_v @ topm_sqrt_d)), axis=1 |
nothing calls this directly
no test coverage detected