Repeat.
(self, *args)
| 25 | self.layer_drop_rate = layer_drop_rate |
| 26 | |
def forward(self, *args):
    """Run the contained modules in order, randomly skipping some in training.

    Each module is called as ``m(*args)`` and its return value becomes the
    ``args`` for the next module, so every module must return something that
    can be unpacked into the next module's positional arguments.

    Args:
        *args: Positional inputs passed to the first module.

    Returns:
        The value returned by the last module that was executed, or the
        original ``args`` tuple if no module ran.

    Notes:
        When ``self.training`` is false every module runs and no random
        numbers are drawn, so evaluation does not advance the global torch
        RNG state. When ``self.training`` is true, one uniform sample in
        ``[0, 1)`` is drawn per module and the module is skipped if its
        sample is below ``self.layer_drop_rate``.
    """
    # NOTE(review): the enclosing class is outside this view; this relies on
    # ``self`` being iterable with ``len()`` and exposing ``training`` and
    # ``layer_drop_rate`` (presumably a torch.nn.Sequential subclass) -
    # confirm against the class definition.
    if not self.training:
        # Evaluation path: nothing is dropped, so skip sampling entirely.
        for module in self:
            args = module(*args)
        return args

    # Training path: draw all samples up front, one per contained module.
    keep_samples = torch.empty(len(self)).uniform_()
    for idx, module in enumerate(self):
        if keep_samples[idx] >= self.layer_drop_rate:
            args = module(*args)
    return args
| 34 | |
| 35 | |
| 36 | def repeat(N, fn, layer_drop_rate=0.0): |