(self,
models,
dataset,
*,
output_dir,
load_dir,
step,
max_steps,
batch_size=None,
batch_size_per_gpu=None,
batch_split=None,
optimizer={},
lr_scheduler=None,
elastic=None,
grad_clip=None,
ema_rate=0.9999,
fp16_mode=None,
mix_precision_mode='inflat_all',
mix_precision_dtype='float16',
fp16_scale_growth=1e-3,
parallel_mode='ddp',
finetune_ckpt=None,
log_param_stats=False,
prefetch_data=True,
snapshot_batch_size=4,
snapshot_num_samples=64,
num_workers=None,
debug=False,
i_print=1000,
i_log=500,
i_sample=10000,
i_save=10000,
i_ddpcheck=10000,
wandb_run=None, # wandb run object
**kwargs
)
| 63 | i_ddpcheck (int): DDP check interval. |
| 64 | """ |
| 65 | def __init__(self, |
| 66 | models, |
| 67 | dataset, |
| 68 | *, |
| 69 | output_dir, |
| 70 | load_dir, |
| 71 | step, |
| 72 | max_steps, |
| 73 | batch_size=None, |
| 74 | batch_size_per_gpu=None, |
| 75 | batch_split=None, |
| 76 | optimizer={}, |
| 77 | lr_scheduler=None, |
| 78 | elastic=None, |
| 79 | grad_clip=None, |
| 80 | ema_rate=0.9999, |
| 81 | fp16_mode=None, |
| 82 | mix_precision_mode='inflat_all', |
| 83 | mix_precision_dtype='float16', |
| 84 | fp16_scale_growth=1e-3, |
| 85 | parallel_mode='ddp', |
| 86 | finetune_ckpt=None, |
| 87 | log_param_stats=False, |
| 88 | prefetch_data=True, |
| 89 | snapshot_batch_size=4, |
| 90 | snapshot_num_samples=64, |
| 91 | num_workers=None, |
| 92 | debug=False, |
| 93 | i_print=1000, |
| 94 | i_log=500, |
| 95 | i_sample=10000, |
| 96 | i_save=10000, |
| 97 | i_ddpcheck=10000, |
| 98 | wandb_run=None, # wandb run object |
| 99 | **kwargs |
| 100 | ): |
| 101 | assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.' |
| 102 | |
| 103 | self.models = models |
| 104 | self.dataset = dataset |
| 105 | self.batch_split = batch_split if batch_split is not None else 1 |
| 106 | self.max_steps = max_steps |
| 107 | self.debug = debug |
| 108 | self.optimizer_config = optimizer |
| 109 | self.lr_scheduler_config = lr_scheduler |
| 110 | self.elastic_controller_config = elastic |
| 111 | self.grad_clip = grad_clip |
| 112 | self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate |
| 113 | if fp16_mode is not None: |
| 114 | mix_precision_dtype = 'float16' |
| 115 | mix_precision_mode = fp16_mode |
| 116 | self.mix_precision_mode = mix_precision_mode |
| 117 | self.mix_precision_dtype = str_to_dtype(mix_precision_dtype) |
| 118 | self.fp16_scale_growth = fp16_scale_growth |
| 119 | self.parallel_mode = parallel_mode |
| 120 | self.log_param_stats = log_param_stats |
| 121 | self.prefetch_data = prefetch_data |
| 122 | self.snapshot_batch_size = snapshot_batch_size |
Nothing calls this directly.
No test coverage detected.