MCPcopy
hub / github.com/TencentARC/Pixal3D / __init__

Method __init__

pixal3d/trainers/basic.py:65–180  ·  view source on GitHub ↗
(self,
        models,
        dataset,
        *,
        output_dir,
        load_dir,
        step,
        max_steps,
        batch_size=None,
        batch_size_per_gpu=None,
        batch_split=None,
        optimizer={},
        lr_scheduler=None,
        elastic=None,
        grad_clip=None,
        ema_rate=0.9999,
        fp16_mode=None,
        mix_precision_mode='inflat_all',
        mix_precision_dtype='float16',
        fp16_scale_growth=1e-3,
        parallel_mode='ddp',
        finetune_ckpt=None,
        log_param_stats=False,
        prefetch_data=True,
        snapshot_batch_size=4,
        snapshot_num_samples=64,
        num_workers=None,
        debug=False,
        i_print=1000,
        i_log=500,
        i_sample=10000,
        i_save=10000,
        i_ddpcheck=10000,
        wandb_run=None,  # wandb run object
        **kwargs
    )

Source from the content-addressed store, hash-verified

63 i_ddpcheck (int): DDP check interval.
64 """
65 def __init__(self,
66 models,
67 dataset,
68 *,
69 output_dir,
70 load_dir,
71 step,
72 max_steps,
73 batch_size=None,
74 batch_size_per_gpu=None,
75 batch_split=None,
76 optimizer={},
77 lr_scheduler=None,
78 elastic=None,
79 grad_clip=None,
80 ema_rate=0.9999,
81 fp16_mode=None,
82 mix_precision_mode='inflat_all',
83 mix_precision_dtype='float16',
84 fp16_scale_growth=1e-3,
85 parallel_mode='ddp',
86 finetune_ckpt=None,
87 log_param_stats=False,
88 prefetch_data=True,
89 snapshot_batch_size=4,
90 snapshot_num_samples=64,
91 num_workers=None,
92 debug=False,
93 i_print=1000,
94 i_log=500,
95 i_sample=10000,
96 i_save=10000,
97 i_ddpcheck=10000,
98 wandb_run=None, # wandb run object
99 **kwargs
100 ):
101 assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'
102
103 self.models = models
104 self.dataset = dataset
105 self.batch_split = batch_split if batch_split is not None else 1
106 self.max_steps = max_steps
107 self.debug = debug
108 self.optimizer_config = optimizer
109 self.lr_scheduler_config = lr_scheduler
110 self.elastic_controller_config = elastic
111 self.grad_clip = grad_clip
112 self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
113 if fp16_mode is not None:
114 mix_precision_dtype = 'float16'
115 mix_precision_mode = fp16_mode
116 self.mix_precision_mode = mix_precision_mode
117 self.mix_precision_dtype = str_to_dtype(mix_precision_dtype)
118 self.fp16_scale_growth = fp16_scale_growth
119 self.parallel_mode = parallel_mode
120 self.log_param_stats = log_param_stats
121 self.prefetch_data = prefetch_data
122 self.snapshot_batch_size = snapshot_batch_size

Callers

nothing calls this directly

Calls 6

init_models_and_more — method · 0.95
prepare_dataloader — method · 0.95
load — method · 0.95
finetune_from — method · 0.95
check_ddp — method · 0.95
str_to_dtype — function · 0.70

Tested by

no test coverage detected