Sparse optimizer state initializer. Parameters: `shape` (tuple of ints) - the shape of the state tensor; `dtype` (torch dtype) - the data type of the state tensor.
(shape, dtype)
| 449 | |
| 450 | |
def initializer(shape, dtype):
    """Build the initial, all-zero state tensor for a sparse optimizer.

    Parameters
    ----------
    shape : tuple of ints
        The shape of the state tensor
    dtype : torch dtype
        The data type of the state tensor

    Returns
    -------
    tensor
        The result of ``th.zeros(shape, dtype=dtype)``: a newly created
        tensor with the requested shape and data type, filled with zeros.
    """
    # Optimizer state starts from zero, so the allocation is returned directly.
    return th.zeros(shape, dtype=dtype)
| 463 | |
| 464 | |
| 465 | class SparseAdagrad(DistSparseGradOptimizer): |
no outgoing calls
no test coverage detected