(
self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
)
| 454 | return nrows, ncols, C_indptr, C_indices, C_eids, C_weights |
| 455 | |
def backward(
    self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights
):
    """Back-propagate the gradient of the sparse product C into A and B.

    Only ``dC_weights`` (the gradient w.r.t. the edge weights of C) is
    used. The gradients for the structural outputs (``dnrows``,
    ``dncols``, ``dC_indptr``, ``dC_indices``, ``dC_eids``) are ignored.

    Returns:
        A pair ``(dA_weights, dB_weights)``: the gradients for the edge
        weights of A and B, each passed through ``_csrmask`` against that
        operand's own graph index.
    """
    # Only the last argument is meaningful.
    # Graph index of C; presumably stored on the instance by forward().
    graph_c = self.backward_cache
    saved_a, saved_b = self.saved_tensors
    gidx_a, gidx_b = self.gidxA, self.gidxB

    # Assuming `_csrmm` multiplies two CSR matrices and `reverse()`
    # transposes a graph index, these are dA = dC @ B^T and dB = A^T @ dC.
    grad_a_graph, grad_a = _csrmm(
        graph_c,
        dC_weights,
        gidx_b.reverse(),
        saved_b,
        gidx_a.number_of_ntypes(),
    )
    grad_b_graph, grad_b = _csrmm(
        gidx_a.reverse(),
        saved_a,
        graph_c,
        dC_weights,
        gidx_b.number_of_ntypes(),
    )

    # Presumably restricts each product to the sparsity pattern of the
    # original operand, so the gradient lines up with its edge weights.
    return (
        _csrmask(grad_a_graph, grad_a, gidx_a),
        _csrmask(grad_b_graph, grad_b, gidx_b),
    )
| 479 | |
| 480 | |
| 481 | def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes): |
nothing calls this directly
no test coverage detected