(self, mode=True)
| 253 | |
@torch.no_grad()
def train(self, mode=True):
    """Switch training/evaluation mode and manage the cached bias ``ab``.

    In evaluation mode the gathered bias
    ``self.attention_biases[:, self.attention_bias_idxs]`` is stored once as
    ``self.ab``; when switching to training mode any such cached copy is
    removed.

    Args:
        mode: ``True`` to enter training mode, ``False`` for evaluation mode.

    Returns:
        ``self``, so calls can be chained (``module.train()`` /
        ``module.eval()``) the way ``torch.nn.Module.train`` allows.
        NOTE(review): assumes the base class is an ``nn.Module``; the class
        header is not visible from this method.
    """
    super().train(mode)
    if mode:
        # Training: make sure no cached bias is kept. The earlier
        # ``if mode and hasattr(...) / else`` form fell into the caching
        # branch when training mode was requested and no cache existed yet,
        # leaving a stale, no-grad copy of the bias on the module.
        if hasattr(self, 'ab'):
            del self.ab
    else:
        # Evaluation: gather the bias once (under no_grad) and cache it.
        self.ab = self.attention_biases[:, self.attention_bias_idxs]
    # The override previously returned None, which broke chaining.
    return self
| 261 | |
| 262 | def forward(self, x): # x (B,N,C) |
| 263 | B, N, _ = x.shape |
no outgoing calls
no test coverage detected