(O)
| 218 | |
| 219 | @pytest.mark.parametrize("O", [1, 2, 8]) |
| 220 | def test_rgcn(O): |
| 221 | etype = [] |
| 222 | g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to( |
| 223 | F.ctx() |
| 224 | ) |
| 225 | # 5 edge types, assigned to the edges round-robin |
| 226 | R = 5 |
| 227 | for i in range(g.num_edges()): |
| 228 | etype.append(i % 5) |
| 229 | B = 2 |
| 230 | I = 10 |
| 231 | |
| 232 | rgc_basis = nn.RelGraphConv(I, O, R, "basis", B) |
| 233 | rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True) |
| 234 | rgc_basis_low.weight = rgc_basis.weight |
| 235 | rgc_basis_low.w_comp = rgc_basis.w_comp |
| 236 | rgc_basis_low.loop_weight = rgc_basis.loop_weight |
| 237 | h = tf.random.normal((100, I)) |
| 238 | r = tf.constant(etype) |
| 239 | h_new = rgc_basis(g, h, r) |
| 240 | h_new_low = rgc_basis_low(g, h, r) |
| 241 | assert list(h_new.shape) == [100, O] |
| 242 | assert list(h_new_low.shape) == [100, O] |
| 243 | assert F.allclose(h_new, h_new_low) |
| 244 | |
| 245 | if O % B == 0: |
| 246 | rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B) |
| 247 | rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True) |
| 248 | rgc_bdd_low.weight = rgc_bdd.weight |
| 249 | rgc_bdd_low.loop_weight = rgc_bdd.loop_weight |
| 250 | h = tf.random.normal((100, I)) |
| 251 | r = tf.constant(etype) |
| 252 | h_new = rgc_bdd(g, h, r) |
| 253 | h_new_low = rgc_bdd_low(g, h, r) |
| 254 | assert list(h_new.shape) == [100, O] |
| 255 | assert list(h_new_low.shape) == [100, O] |
| 256 | assert F.allclose(h_new, h_new_low) |
| 257 | |
| 258 | # repeat the same checks, this time passing an explicit all-zero norm tensor of shape (num_edges, 1) |
| 259 | norm = tf.zeros((g.num_edges(), 1)) |
| 260 | |
| 261 | rgc_basis = nn.RelGraphConv(I, O, R, "basis", B) |
| 262 | rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True) |
| 263 | rgc_basis_low.weight = rgc_basis.weight |
| 264 | rgc_basis_low.w_comp = rgc_basis.w_comp |
| 265 | rgc_basis_low.loop_weight = rgc_basis.loop_weight |
| 266 | h = tf.random.normal((100, I)) |
| 267 | r = tf.constant(etype) |
| 268 | h_new = rgc_basis(g, h, r, norm) |
| 269 | h_new_low = rgc_basis_low(g, h, r, norm) |
| 270 | assert list(h_new.shape) == [100, O] |
| 271 | assert list(h_new_low.shape) == [100, O] |
| 272 | assert F.allclose(h_new, h_new_low) |
| 273 | |
| 274 | if O % B == 0: |
| 275 | rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B) |
| 276 | rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True) |
| 277 | rgc_bdd_low.weight = rgc_bdd.weight |
no test coverage detected