(
self, module, beta=0.9999, start_iteration=1000,
remove_wn_wrapper=True
)
| 45 | remove_wn_wrapper (bool): Whether to remove the weight-norm wrappers from the averaged model. |
| 46 | """ |
def __init__(
        self, module, beta=0.9999, start_iteration=1000,
        remove_wn_wrapper=True, device='cuda'
):
    r"""Wrap ``module`` and create a separate copy of it for averaging.

    Args:
        module (nn.Module): Network being wrapped; a deep copy of it becomes
            ``self.averaged_model``.
        beta (float): Stored as ``self.beta``; presumably the averaging
            coefficient used by the update step (not shown here).
        start_iteration (int): Stored as ``self.start_iteration``. Averaging
            is meant to start only after this many tracked updates.
        remove_wn_wrapper (bool): If True, copy the current weights into the
            averaged model with ``copy_s2t`` and then apply
            ``remove_weight_norms`` to every submodule of the copy. If False,
            only switch the averaged model to eval mode.
        device (str or torch.device): Device for the averaged copy and the
            update counter. Defaults to ``'cuda'``, the value that used to be
            hard-coded. Pass ``'cpu'`` to run without a GPU.
    """
    super(ModelAverage, self).__init__()
    self.module = module
    # Deep copy so the averaged weights are independent tensors rather than
    # references to the live module's parameters.
    self.averaged_model = copy.deepcopy(self.module).to(device)
    self.beta = beta
    self.remove_wn_wrapper = remove_wn_wrapper
    self.start_iteration = start_iteration
    # Number of updates tracked so far. The first `start_iteration` updates
    # are ignored and averaging starts after them. A buffer (not a plain
    # attribute) is used so the counter is part of the module state.
    self.register_buffer(
        'num_updates_tracked',
        torch.tensor(0, dtype=torch.long, device=device))
    if self.remove_wn_wrapper:
        # Order matters: weights are copied into the averaged model before
        # the wrappers are stripped. Presumably copy_s2t relies on the
        # wrapped parameter layout -- confirm against copy_s2t.
        self.copy_s2t()
        self.averaged_model.apply(remove_weight_norms)
        # NOTE(review): `dim` is only set on this branch; its consumer is
        # not visible here.
        self.dim = 0
    else:
        self.averaged_model.eval()

    # The averaged model is never optimised directly.
    requires_grad(self.averaged_model, False)
| 89 | |
| 90 | def forward(self, *inputs, **kwargs): |
| 91 | r"""PyTorch module forward function overload.""" |
nothing calls this directly
no test coverage detected