(self)
| 225 | out_agg.backward() |
| 226 | |
| 227 | def test_train_eval(self): |
| 228 | class TestModel(torch.nn.Module): # toy model whose forward branches on self.training, so train/eval graphs differ |
| 229 | def __init__(self): |
| 230 | super().__init__() |
| 231 | self.dropout = torch.nn.Dropout(p=1.0) # p=1.0: zeroes its input in train mode, identity in eval mode |
| 232 | |
| 233 | def forward(self, x): |
| 234 | x = x.float().mean() # collapse the input to a 0-dim float tensor |
| 235 | x = self.dropout(x) # dropout -> return node "dropout" (both modes) |
| 236 | if self.training: # mode-dependent control flow: train graph has "add", eval graph has "mul" |
| 237 | x += 100 # add -> return node "add" (train only) |
| 238 | else: |
| 239 | x *= 0 # mul -> return node "mul" (eval only) |
| 240 | x -= 0 # sub -> return node "sub" (both modes); leaves the value unchanged |
| 241 | return x |
| 242 | |
| 243 | model = TestModel() |
| 244 | |
| 245 | train_return_nodes = ["dropout", "add", "sub"] |
| 246 | eval_return_nodes = ["dropout", "mul", "sub"] |
| 247 | |
| 248 | def checks(model, mode): # run model on ones(10, 10) and assert the outputs expected for mode ("train" or "eval") |
| 249 | with torch.no_grad(): |
| 250 | out = model(torch.ones(10, 10)) # dict keyed by return-node name |
| 251 | if mode == "train": |
| 252 | # Check that dropout is respected |
| 253 | assert out["dropout"].item() == 0 # dropout(p=1.0) zeroes the mean of ones |
| 254 | # Check that control flow dependent on training_mode is respected |
| 255 | assert out["sub"].item() == 100 # 0 + 100 - 0 |
| 256 | assert "add" in out |
| 257 | assert "mul" not in out |
| 258 | elif mode == "eval": # NOTE(review): any other mode value runs no assertions at all |
| 259 | # Check that dropout is respected |
| 260 | assert out["dropout"].item() == 1 # dropout is identity in eval; mean of ones is 1 |
| 261 | # Check that control flow dependent on training_mode is respected |
| 262 | assert out["sub"].item() == 0 # 1 * 0 - 0 |
| 263 | assert "mul" in out |
| 264 | assert "add" not in out |
| 265 | |
| 266 | # Starting from train mode |
| 267 | model.train() |
| 268 | fx_model = self._create_feature_extractor( |
| 269 | model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes |
| 270 | ) |
| 271 | # Check that the models stay in their original training state |
| 272 | assert model.training |
| 273 | assert fx_model.training |
| 274 | # Check outputs |
| 275 | checks(fx_model, "train") |
| 276 | # Check outputs after switching to eval mode |
| 277 | fx_model.eval() |
| 278 | checks(fx_model, "eval") |
| 279 | |
| 280 | # Starting from eval mode |
| 281 | model.eval() |
| 282 | fx_model = self._create_feature_extractor( |
| 283 | model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes |
| 284 | ) |
nothing calls this directly
no test coverage detected