(self, exp, args)
| 32 | |
| 33 | class Trainer: |
def __init__(self, exp, args):
    """Record the lightweight, configuration-derived state of a training run.

    Heavier objects such as the model and the optimizer are intentionally not
    created here; they are built later by ``before_train``.

    Args:
        exp: Experiment configuration. This method reads ``max_epoch``,
            ``ema``, ``input_size``, ``print_interval`` and ``output_dir``.
        args: Parsed launch options. This method reads ``fp16``,
            ``local_rank`` and ``experiment_name``.
    """
    self.exp = exp
    self.args = args
    use_fp16 = args.fp16

    # ---- training state ----
    self.max_epoch = exp.max_epoch
    self.amp_training = use_fp16
    # With enabled=False, torch's GradScaler methods become no-ops, so the
    # scaler can be used unconditionally whether or not fp16 is on.
    self.scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    self.is_distributed = get_world_size() > 1
    self.rank = get_rank()
    self.local_rank = args.local_rank
    # CUDA device string for this process, indexed by its local rank.
    self.device = "cuda:{}".format(self.local_rank)
    self.use_model_ema = exp.ema

    # ---- data / dataloader state ----
    if use_fp16:
        self.data_type = torch.float16
    else:
        self.data_type = torch.float32
    self.input_size = exp.input_size
    # NOTE(review): presumably the best evaluation AP seen so far; it is
    # only initialised here and must be updated elsewhere.
    self.best_ap = 0

    # ---- metric buffer and output location ----
    self.meter = MeterBuffer(window_size=exp.print_interval)
    self.file_name = os.path.join(exp.output_dir, args.experiment_name)

    is_main_process = self.rank == 0
    if is_main_process:
        # Only rank 0 creates the output directory.
        os.makedirs(self.file_name, exist_ok=True)

    # Every rank calls setup_logger and passes its own rank along.
    setup_logger(
        self.file_name,
        distributed_rank=self.rank,
        filename="train_log.txt",
        mode="a",
    )
| 68 | |
| 69 | def train(self): |
| 70 | self.before_train() |
Nothing calls this directly.
No test coverage detected.