(gidxA, A_weights, gidxB)
| 445 | |
| 446 | |
def csrmask_real(gidxA, A_weights, gidxB):
    """Apply ``_csrmask`` forward and return it with a gradient callback.

    The forward result is ``_csrmask(gidxA, A_weights, gidxB)``.

    The returned ``grad`` callable maps an incoming gradient back by calling
    ``_csrmask`` with the two graph indices swapped (``gidxB`` -> ``gidxA``).

    NOTE(review): the ``(value, grad_fn)`` return shape matches the
    convention used by custom-gradient wrappers such as
    ``tf.custom_gradient``. Confirm against the decorator or caller, which is
    not visible here.

    Returns:
        A 2-tuple ``(masked_weights, grad)``.
    """

    def grad(dB_weights):
        # Backward direction: swap source and target indices relative to the
        # forward call.
        return _csrmask(gidxB, dB_weights, gidxA)

    masked_weights = _csrmask(gidxA, A_weights, gidxB)
    return masked_weights, grad
| 454 | |
| 455 | |
| 456 | def csrmask(gidxA, A_weights, gidxB): |