MCPcopy
hub / github.com/pytorch/vision / test_train_eval

Method test_train_eval

test/test_backbone_utils.py:227–292  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

225 out_agg.backward()
226
227 def test_train_eval(self):
228 class TestModel(torch.nn.Module):
229 def __init__(self):
230 super().__init__()
231 self.dropout = torch.nn.Dropout(p=1.0)
232
233 def forward(self, x):
234 x = x.float().mean()
235 x = self.dropout(x) # dropout
236 if self.training:
237 x += 100 # add
238 else:
239 x *= 0 # mul
240 x -= 0 # sub
241 return x
242
243 model = TestModel()
244
245 train_return_nodes = ["dropout", "add", "sub"]
246 eval_return_nodes = ["dropout", "mul", "sub"]
247
248 def checks(model, mode):
249 with torch.no_grad():
250 out = model(torch.ones(10, 10))
251 if mode == "train":
252 # Check that dropout is respected
253 assert out["dropout"].item() == 0
254 # Check that control flow dependent on training_mode is respected
255 assert out["sub"].item() == 100
256 assert "add" in out
257 assert "mul" not in out
258 elif mode == "eval":
259 # Check that dropout is respected
260 assert out["dropout"].item() == 1
261 # Check that control flow dependent on training_mode is respected
262 assert out["sub"].item() == 0
263 assert "mul" in out
264 assert "add" not in out
265
266 # Starting from train mode
267 model.train()
268 fx_model = self._create_feature_extractor(
269 model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
270 )
271 # Check that the models stay in their original training state
272 assert model.training
273 assert fx_model.training
274 # Check outputs
275 checks(fx_model, "train")
276 # Check outputs after switching to eval mode
277 fx_model.eval()
278 checks(fx_model, "eval")
279
280 # Starting from eval mode
281 model.eval()
282 fx_model = self._create_feature_extractor(
283 model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
284 )

Callers

nothing calls this directly

Calls 3

TestModel (class) · 0.85
train (method) · 0.80

Tested by

no test coverage detected