(
self,
*,
model,
use_fp16=False,
fp16_scale_growth=1e-3,
initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
)
| 147 | |
class MixedPrecisionTrainer:
    """Own a model's parameters, optionally alongside separate master copies for fp16 training.

    When ``use_fp16`` is false, ``master_params`` is simply the model's own
    parameter list. When it is true, master params are built from grouped
    parameter metadata and the model is then converted via its
    ``convert_to_fp16()`` method.
    """

    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        """Capture the model's parameters and set up the fp16 path if requested.

        :param model: object exposing ``parameters()``, ``named_parameters()``
            and (when ``use_fp16`` is true) ``convert_to_fp16()``.
        :param use_fp16: whether to build separate master params and convert
            the model.
        :param fp16_scale_growth: stored on the instance; not used in this
            constructor (presumably consumed by loss-scale update logic
            elsewhere — verify).
        :param initial_lg_loss_scale: starting value for ``lg_loss_scale``
            (the name suggests a log-domain loss scale — confirm).
        """
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        # Snapshot the parameter list *before* any fp16 conversion happens.
        self.model_params = list(self.model.parameters())

        # Defaults for the non-fp16 path: the optimizer-facing params are the
        # model's own parameters, and no grouping metadata is needed.
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if not self.use_fp16:
            return

        groups_and_shapes = get_param_groups_and_shapes(
            self.model.named_parameters()
        )
        self.param_groups_and_shapes = groups_and_shapes
        # Master params are built before the model is converted, so they are
        # derived from the pre-conversion parameters.
        self.master_params = make_master_params(groups_and_shapes)
        self.model.convert_to_fp16()

    def zero_grad(self):
        """Clear gradients on ``self.model_params`` using the module-level ``zero_grad`` helper."""
        # Bare name resolves to the module-level function, not this method.
        zero_grad(self.model_params)
nothing calls this directly
no test coverage detected