MCPcopy
hub / github.com/jindongwang/transferlearning / train

Method train

code/deep/ReMoS/CV_backdoor/finetuner.py:115–251  ·  view source on GitHub ↗
(self, )

Source from the content-addressed store, hash-verified

113 return float(top1) / total * 100, total_ce / (i + 1)
114
115 def train(self, ):
116 model = self.model
117 train_loader = self.train_loader
118 test_loader = self.test_loader
119 iterations = self.args.iterations
120 lr = self.args.lr
121 output_dir = self.args.output_dir
122 teacher = self.teacher
123 args = self.args
124 model = model.to('cuda')
125
126
127 optimizer = optim.SGD(
128 model.parameters(),
129 lr=lr,
130 momentum=args.momentum,
131 weight_decay=args.weight_decay,
132 )
133
134 teacher.eval()
135 ce = CrossEntropyLabelSmooth(train_loader.dataset.num_classes)
136
137 batch_time = MovingAverageMeter('Time', ':6.3f')
138 data_time = MovingAverageMeter('Data', ':6.3f')
139 ce_loss_meter = MovingAverageMeter('CE Loss', ':6.3f')
140 top1_meter = MovingAverageMeter('Acc@1', ':6.2f')
141
142 train_path = osp.join(output_dir, "train.tsv")
143 with open(train_path, 'a') as wf:
144 columns = ['time', 'iter', 'Acc', 'celoss']
145 wf.write('\t'.join(columns) + '\n')
146 test_path = osp.join(output_dir, "test.tsv")
147 with open(test_path, 'a') as wf:
148 columns = ['time', 'iter', 'Acc', 'celoss']
149 wf.write('\t'.join(columns) + '\n')
150 adv_path = osp.join(output_dir, "adv.tsv")
151 with open(adv_path, 'a') as wf:
152 columns = ['time', 'iter', 'Acc', 'AdvAcc', 'ASR']
153 wf.write('\t'.join(columns) + '\n')
154
155 dataloader_iterator = iter(train_loader)
156 for i in range(iterations):
157 model.train()
158 optimizer.zero_grad()
159
160 end = time.time()
161 try:
162 batch, label = next(dataloader_iterator)
163 except:
164 dataloader_iterator = iter(train_loader)
165 batch, label = next(dataloader_iterator)
166 batch, label = batch.to('cuda'), label.to('cuda')
167 data_time.update(time.time() - end)
168
169 top1, ce_loss = self.compute_loss(
170 batch, label, ce
171 )
172 top1_meter.update(top1)

Callers 2

teacher_train (Function) · 0.95
finetune.py (File) · 0.45

Calls 14

update (Method) · 0.95
compute_loss (Method) · 0.95
display (Method) · 0.95
test (Method) · 0.95
step (Method) · 0.80
MovingAverageMeter (Class) · 0.70
ProgressMeter (Class) · 0.70
parameters (Method) · 0.45
write (Method) · 0.45
backward (Method) · 0.45
save (Method) · 0.45

Tested by

no test coverage detected