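Initialize models and more: wrap models for distributed training, build the master and EMA parameter lists, and create the optimizer, learning-rate scheduler, and elastic memory controller.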
(self, **kwargs)
| 218 | return next(list(self.models.values())[0].parameters()).device |
| 219 | |
def init_models_and_more(self, **kwargs):
    """
    Initialize models and more.

    Sets the following attributes on ``self``:

    * ``training_models`` -- when ``self.world_size > 1``, a dict mapping each
      name in ``self.models`` to a ``DDP`` wrapper of that model; otherwise
      ``self.models`` itself (the same dict object, not a copy).
    * ``model_params`` -- every parameter with ``requires_grad`` set, across
      all of ``self.models``, in model-iteration order.
    * ``master_params`` -- the parameter list handed to the optimizer, chosen
      by ``self.mix_precision_mode``:
        - ``'amp'`` / ``None``: ``model_params`` itself;
        - ``'inflat_all'``: the result of ``make_master_params(model_params)``.
      When ``self.mix_precision_dtype`` is ``torch.float16``, ``'amp'`` also
      creates ``self.scaler`` and ``'inflat_all'`` sets ``self.log_scale``.
    * ``ema_params`` -- one deep copy of ``master_params`` per entry of
      ``self.ema_rate``; only created when ``self.is_master`` is true.
    * ``optimizer`` -- built from ``self.optimizer_config`` (``'name'`` and
      ``'args'`` keys).
    * ``lr_scheduler`` -- built from ``self.lr_scheduler_config`` (same keys);
      only set when that config is not ``None``.
    * ``elastic_controller`` -- built from ``self.elastic_controller_config``;
      only set when that config is not ``None``.

    Args:
        **kwargs: Accepted for interface compatibility; not used.

    Raises:
        NotImplementedError: ``self.mix_precision_mode`` is not ``'amp'``,
            ``'inflat_all'`` or ``None``.
        KeyError: an optimizer/scheduler name is found neither in the
            corresponding ``torch.optim`` namespace nor in this module's globals.
        ValueError: an elastic controller is configured but no model is an
            ``ElasticModule`` / ``ElasticModuleMixin``.
    """

    def _resolve(namespace, name):
        # Prefer the built-in PyTorch class; otherwise fall back to a class
        # defined in (or imported into) this module.
        if hasattr(namespace, name):
            return getattr(namespace, name)
        return globals()[name]

    if self.world_size > 1:
        # Prepare distributed data parallel
        self.training_models = {
            name: DDP(
                model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
            for name, model in self.models.items()
        }
    else:
        self.training_models = self.models

    # Build master params. A flat nested comprehension keeps the original
    # ordering without the quadratic list concatenation of ``sum(lists, [])``.
    self.model_params = [
        p
        for model in self.models.values()
        for p in model.parameters()
        if p.requires_grad
    ]
    if self.mix_precision_mode == 'amp':
        self.master_params = self.model_params
        # Loss scaling is only set up for float16 (not e.g. bfloat16).
        if self.mix_precision_dtype == torch.float16:
            self.scaler = torch.GradScaler()
    elif self.mix_precision_mode == 'inflat_all':
        # Separate master copies; presumably higher-precision copies of the
        # model params -- confirm against make_master_params.
        self.master_params = make_master_params(self.model_params)
        if self.mix_precision_dtype == torch.float16:
            # Initial log loss scale; NOTE(review): semantics defined by the
            # training step elsewhere -- confirm.
            self.log_scale = 20.0
    elif self.mix_precision_mode is None:
        self.master_params = self.model_params
    else:
        raise NotImplementedError(f'Mix precision mode {self.mix_precision_mode} is not implemented.')

    # Build EMA params (one independent copy per EMA rate, master process only)
    if self.is_master:
        self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]

    # Initialize optimizer
    optimizer_cls = _resolve(torch.optim, self.optimizer_config['name'])
    self.optimizer = optimizer_cls(self.master_params, **self.optimizer_config['args'])

    # Initialize learning rate scheduler
    if self.lr_scheduler_config is not None:
        scheduler_cls = _resolve(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])
        self.lr_scheduler = scheduler_cls(self.optimizer, **self.lr_scheduler_config['args'])

    # Initialize elastic memory controller
    if self.elastic_controller_config is not None:
        elastic_types = (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not any(isinstance(model, elastic_types) for model in self.models.values()):
            raise ValueError(
                'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
            )
        controller_cls = getattr(elastic_utils, self.elastic_controller_config['name'])
        self.elastic_controller = controller_cls(**self.elastic_controller_config['args'])
no test coverage detected