MCPcopy
hub / github.com/TencentARC/Pixal3D / init_models_and_more

Method init_models_and_more

pixal3d/trainers/basic.py:220–287  ·  view source on GitHub ↗

Initialize models and more.

(self, **kwargs)

Source from the content-addressed store, hash-verified

218 return next(list(self.models.values())[0].parameters()).device
219
def init_models_and_more(self, **kwargs):
    """
    Initialize models and more.

    Sets up, in order: the (optionally DDP-wrapped) training models, the
    trainable and master parameter lists, the EMA parameter copies (master
    rank only), the optimizer, the optional learning rate scheduler and the
    optional elastic memory controller.

    Args:
        **kwargs: Accepted for interface compatibility; not read here.

    Raises:
        NotImplementedError: If ``self.mix_precision_mode`` is not one of
            ``'amp'``, ``'inflat_all'`` or ``None``.
        KeyError: If the configured optimizer or scheduler name is found
            neither in torch nor in this module's globals.
        AssertionError: If an elastic controller is configured but none of
            the models is an elastic module.
    """
    if self.world_size > 1:
        # Prepare distributed data parallel
        self.training_models = {
            name: DDP(
                model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
            for name, model in self.models.items()
        }
    else:
        self.training_models = self.models

    # Build master params
    # Flatten in a single pass (model order, then parameter order);
    # `sum(list_of_lists, [])` would re-copy the accumulated list per model.
    self.model_params = [
        param
        for model in self.models.values()
        for param in model.parameters()
        if param.requires_grad
    ]
    if self.mix_precision_mode == 'amp':
        self.master_params = self.model_params
        if self.mix_precision_dtype == torch.float16:
            self.scaler = torch.GradScaler()
    elif self.mix_precision_mode == 'inflat_all':
        self.master_params = make_master_params(self.model_params)
        if self.mix_precision_dtype == torch.float16:
            self.log_scale = 20.0
    elif self.mix_precision_mode is None:
        self.master_params = self.model_params
    else:
        raise NotImplementedError(f'Mix precision mode {self.mix_precision_mode} is not implemented.')

    # Build EMA params (one independent copy per EMA rate, master rank only)
    if self.is_master:
        self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]

    # Initialize optimizer: torch.optim first, then names visible in this module
    optimizer_name = self.optimizer_config['name']
    if hasattr(torch.optim, optimizer_name):
        optimizer_cls = getattr(torch.optim, optimizer_name)
    else:
        try:
            optimizer_cls = globals()[optimizer_name]
        except KeyError:
            # Same exception type as the plain lookup, but with a usable message.
            raise KeyError(
                f'Optimizer {optimizer_name} not found in torch.optim or in module globals.'
            ) from None
    self.optimizer = optimizer_cls(self.master_params, **self.optimizer_config['args'])

    # Initialize learning rate scheduler
    if self.lr_scheduler_config is not None:
        scheduler_name = self.lr_scheduler_config['name']
        if hasattr(torch.optim.lr_scheduler, scheduler_name):
            scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name)
        else:
            try:
                scheduler_cls = globals()[scheduler_name]
            except KeyError:
                raise KeyError(
                    f'LR scheduler {scheduler_name} not found in torch.optim.lr_scheduler or in module globals.'
                ) from None
        self.lr_scheduler = scheduler_cls(self.optimizer, **self.lr_scheduler_config['args'])

    # Initialize elastic memory controller
    if self.elastic_controller_config is not None:
        elastic_types = (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)
        # Explicit raise (not `assert`) so the check survives `python -O`;
        # AssertionError is kept so the exception type is unchanged.
        if not any(isinstance(model, elastic_types) for model in self.models.values()):
            raise AssertionError(
                'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
            )
        self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args'])
    # NOTE(review): the page header gives this method's range as 220-287 but
    # the listing ends at 277; any trailing statements are not visible here.

Callers 1

__init__ — Method · 0.95

Calls 2

make_master_params — Function · 0.85

Tested by

no test coverage detected