(self, )
| 113 | return float(top1) / total * 100, total_ce / (i + 1) |
| 114 | |
| 115 | def train(self, ): |
| 116 | model = self.model |
| 117 | train_loader = self.train_loader |
| 118 | test_loader = self.test_loader |
| 119 | iterations = self.args.iterations |
| 120 | lr = self.args.lr |
| 121 | output_dir = self.args.output_dir |
| 122 | teacher = self.teacher |
| 123 | args = self.args |
| 124 | model = model.to('cuda') |
| 125 | |
| 126 | |
| 127 | optimizer = optim.SGD( |
| 128 | model.parameters(), |
| 129 | lr=lr, |
| 130 | momentum=args.momentum, |
| 131 | weight_decay=args.weight_decay, |
| 132 | ) |
| 133 | |
| 134 | teacher.eval() |
| 135 | ce = CrossEntropyLabelSmooth(train_loader.dataset.num_classes) |
| 136 | |
| 137 | batch_time = MovingAverageMeter('Time', ':6.3f') |
| 138 | data_time = MovingAverageMeter('Data', ':6.3f') |
| 139 | ce_loss_meter = MovingAverageMeter('CE Loss', ':6.3f') |
| 140 | top1_meter = MovingAverageMeter('Acc@1', ':6.2f') |
| 141 | |
| 142 | train_path = osp.join(output_dir, "train.tsv") |
| 143 | with open(train_path, 'a') as wf: |
| 144 | columns = ['time', 'iter', 'Acc', 'celoss'] |
| 145 | wf.write('\t'.join(columns) + '\n') |
| 146 | test_path = osp.join(output_dir, "test.tsv") |
| 147 | with open(test_path, 'a') as wf: |
| 148 | columns = ['time', 'iter', 'Acc', 'celoss'] |
| 149 | wf.write('\t'.join(columns) + '\n') |
| 150 | adv_path = osp.join(output_dir, "adv.tsv") |
| 151 | with open(adv_path, 'a') as wf: |
| 152 | columns = ['time', 'iter', 'Acc', 'AdvAcc', 'ASR'] |
| 153 | wf.write('\t'.join(columns) + '\n') |
| 154 | |
| 155 | dataloader_iterator = iter(train_loader) |
| 156 | for i in range(iterations): |
| 157 | model.train() |
| 158 | optimizer.zero_grad() |
| 159 | |
| 160 | end = time.time() |
| 161 | try: |
| 162 | batch, label = next(dataloader_iterator) |
| 163 | except: |
| 164 | dataloader_iterator = iter(train_loader) |
| 165 | batch, label = next(dataloader_iterator) |
| 166 | batch, label = batch.to('cuda'), label.to('cuda') |
| 167 | data_time.update(time.time() - end) |
| 168 | |
| 169 | top1, ce_loss = self.compute_loss( |
| 170 | batch, label, ce |
| 171 | ) |
| 172 | top1_meter.update(top1) |
no test coverage detected