(self)
| 403 | |
| 404 | class BnB4BitTrainingTests(Base4bitTests): |
def setUp(self):
    """Prepare the 4-bit NF4 quantized SD3 transformer used by the tests.

    Runs garbage collection and ``backend_empty_cache(torch_device)`` first,
    then loads the ``transformer`` subfolder of ``self.model_name`` with a
    bitsandbytes config (4-bit, ``nf4`` quant type, float16 compute dtype)
    and stores the result on ``self.model_4bit``.
    """
    # Clean up before the load: collect garbage, then clear the backend cache.
    gc.collect()
    backend_empty_cache(torch_device)

    # Quantization settings: 4-bit weights, NF4 quant type, fp16 compute.
    quantization_settings = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.float16,
    }
    quantization_config = BitsAndBytesConfig(**quantization_settings)

    load_kwargs = {
        "subfolder": "transformer",
        "quantization_config": quantization_config,
        "device_map": torch_device,
    }
    self.model_4bit = SD3Transformer2DModel.from_pretrained(self.model_name, **load_kwargs)
| 417 | |
| 418 | def test_training(self): |
| 419 | # Step 1: freeze all parameters |
nothing calls this directly
no test coverage detected