MCPcopy
hub / github.com/NVlabs/imaginaire / __init__

Method __init__

imaginaire/utils/model_average.py:47–88  ·  view source on GitHub ↗
(
            self, module, beta=0.9999, start_iteration=1000,
            remove_wn_wrapper=True
    )

Source from the content-addressed store, hash-verified

45 remove_sn (bool): Whether we remove the spectral norm when we it.
46 """
def __init__(
        self, module, beta=0.9999, start_iteration=1000,
        remove_wn_wrapper=True, device='cuda'
):
    r"""Wrap ``module`` and build a frozen copy of it for weight averaging.

    Args:
        module (nn.Module): The network being trained. It is kept as
            ``self.module`` and deep-copied into ``self.averaged_model``.
        beta (float): Stored as ``self.beta``; presumably the decay factor
            used by the moving-average update (update code not visible
            here — confirm).
        start_iteration (int): Stored as ``self.start_iteration``; the
            averaging is meant to start only after this many tracked
            updates.
        remove_wn_wrapper (bool): If True, copy the current weights into
            the averaged model and apply ``remove_weight_norms`` to it and
            every submodule; otherwise put the averaged model in eval mode.
        device (str or torch.device): Device the averaged model and the
            update counter are placed on. Defaults to ``'cuda'``, the
            previously hard-coded value; pass ``'cpu'`` (or the device of
            ``module``) to use the wrapper without a GPU.
    """
    super(ModelAverage, self).__init__()
    self.module = module
    # Deep copy so the averaged model owns independent parameter tensors;
    # a shallow copy would share them with ``module``, and the average
    # would simply track the live training weights.
    self.averaged_model = copy.deepcopy(self.module).to(device)
    self.beta = beta
    self.remove_wn_wrapper = remove_wn_wrapper
    self.start_iteration = start_iteration
    # Number of updates seen so far. It is registered as a (persistent)
    # buffer so it is saved in and restored from the state dict. The first
    # ``start_iteration`` updates are meant to be skipped before averaging.
    self.register_buffer('num_updates_tracked',
                         torch.tensor(0, dtype=torch.long, device=device))
    if self.remove_wn_wrapper:
        # First copy the current weights of ``module`` into the averaged
        # model (copy_s2t: source -> target), then strip the weight-norm
        # wrappers from the averaged copy only.
        self.copy_s2t()
        self.averaged_model.apply(remove_weight_norms)
        # NOTE(review): ``self.dim`` is only set on this branch; confirm
        # nothing reads it when remove_wn_wrapper is False.
        self.dim = 0
    else:
        # NOTE(review): eval mode is only set on this branch; confirm the
        # remove_wn_wrapper=True path intentionally leaves train mode on.
        self.averaged_model.eval()

    # Averaged model does not require grad; it is updated by the
    # averaging step, not by the optimizer.
    requires_grad(self.averaged_model, False)
89
90 def forward(self, *inputs, **kwargs):
91 r"""PyTorch module forward function overload."""

Callers

nothing calls this directly

Calls 4

copy_s2t · Method · 0.95
requires_grad · Function · 0.90
apply · Method · 0.80
eval · Method · 0.80

Tested by

no test coverage detected