(self, resolution=32)
| 28 | |
| 29 | class TrainingTests(unittest.TestCase): |
| 30 | def get_model_optimizer(self, resolution=32): |
| 31 | set_seed(0) |
| 32 | model = UNet2DModel(sample_size=resolution, in_channels=3, out_channels=3) |
| 33 | optimizer = torch.optim.SGD(model.parameters(), lr=0.0001) |
| 34 | return model, optimizer |
| 35 | |
| 36 | @slow |
| 37 | def test_training_step_equality(self): |
no test coverage detected