Load weight.
(param: torch.nn.Parameter, loaded_weight: torch.Tensor, **kwargs)
| 17 | |
| 18 | |
def load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor, **kwargs):
    """Load ``loaded_weight`` into ``param``.

    If ``param`` has a ``weight_loader`` attribute, it is called as
    ``param.weight_loader(param, loaded_weight, **kwargs)`` and every extra
    keyword argument is forwarded to it. Otherwise ``default_weight_loader``
    is used, which takes no extra options.

    Args:
        param: Destination parameter.
        loaded_weight: Tensor whose data should be loaded into ``param``.
        **kwargs: Options forwarded to ``param.weight_loader``. Must be empty
            when ``param`` has no ``weight_loader``.

    Raises:
        TypeError: If keyword arguments are passed but ``param`` has no
            ``weight_loader`` that could receive them.
    """
    if hasattr(param, 'weight_loader'):
        param.weight_loader(param, loaded_weight, **kwargs)
        return
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently discard the caller's kwargs.
    if kwargs:
        raise TypeError(
            'load_weight() got keyword arguments '
            f'{sorted(kwargs)} but the parameter has no weight_loader '
            'to receive them')
    default_weight_loader(param, loaded_weight)
| 26 | |
| 27 | |
| 28 | def default_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor): |
no test coverage detected