(
self,
*,
scaling_config: Optional[ScalingConfig] = None,
run_config: Optional[RunConfig] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
metadata: Optional[Dict[str, Any]] = None,
resume_from_checkpoint: Optional[Checkpoint] = None,
)
# NOTE(review): shared mutable class-level list; presumably overridden by
# subclasses to name fields exposed to the tuner param space — confirm.
_fields_for_tuner_param_space = []

def __init__(
    self,
    *,
    scaling_config: Optional[ScalingConfig] = None,
    run_config: Optional[RunConfig] = None,
    datasets: Optional[Dict[str, GenDataset]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    resume_from_checkpoint: Optional[Checkpoint] = None,
):
    """Store the trainer configuration, validate it, and record usage.

    Args:
        scaling_config: Stored as-is; a fresh ``ScalingConfig()`` is used
            when ``None``.
        run_config: Stored as a shallow copy (``copy.copy``), presumably so
            later changes to ``self.run_config`` do not touch the caller's
            object; a fresh ``RunConfig()`` is used when ``None``.
        datasets: Stored by reference; an empty dict is used when ``None``.
        metadata: Stored as-is on ``self.metadata``. Passing a non-None
            value logs a deprecation warning when v2 migration warnings
            are enabled.
        resume_from_checkpoint: Stored as ``self.starting_checkpoint``.
            Passing a non-None value logs a deprecation warning when v2
            migration warnings are enabled.
    """
    if scaling_config is None:
        scaling_config = ScalingConfig()
    self.scaling_config = scaling_config

    if run_config is None:
        self.run_config = RunConfig()
    else:
        self.run_config = copy.copy(run_config)

    self.metadata = metadata
    self.datasets = {} if datasets is None else datasets
    self.starting_checkpoint = resume_from_checkpoint

    if _v2_migration_warnings_enabled():
        # Each deprecated argument paired with the warning emitted when it
        # is supplied; checked in the same order as before.
        deprecated_arguments = (
            (metadata, _GET_METADATA_DEPRECATION_MESSAGE),
            (resume_from_checkpoint, _RESUME_FROM_CHECKPOINT_DEPRECATION_WARNING),
        )
        for supplied_value, warning_message in deprecated_arguments:
            if supplied_value is not None:
                _log_deprecation_warning(warning_message)

    # Restore state starts empty; only `BaseTrainer.restore` should set it.
    self._restore_path = None
    self._restore_storage_filesystem = None

    # Validation runs after every attribute above has been assigned.
    self._validate_attributes()

    usage_lib.record_library_usage("train")
    air_usage.tag_air_trainer(self)
| 266 | |
| 267 | @classmethod |
| 268 | @Deprecated(message=_TRAINER_RESTORE_DEPRECATION_WARNING) |
nothing calls this directly
no test coverage detected