| 168 | |
| 169 | |
def set_model(opt):
    """Build the cross-entropy classifier and its loss function.

    Args:
        opt: parsed options object. Reads ``opt.model`` (passed as ``name``
            to ``SupCEResNet``), ``opt.n_cls`` (passed as ``num_classes``)
            and ``opt.syncBN`` (when truthy, BatchNorm layers are converted
            with ``apex.parallel.convert_syncbn_model``).

    Returns:
        A ``(model, criterion)`` tuple. If CUDA is available, both are moved
        to the GPU, the model is wrapped in ``torch.nn.DataParallel`` when
        more than one device is visible, and ``cudnn.benchmark`` is enabled.
        Otherwise both are returned as built, on the CPU.
    """
    net = SupCEResNet(name=opt.model, num_classes=opt.n_cls)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Optionally swap BatchNorm layers for apex's synchronized variant.
    if opt.syncBN:
        net = apex.parallel.convert_syncbn_model(net)

    # No GPU: hand the modules back unchanged.
    if not torch.cuda.is_available():
        return net, loss_fn

    # Only replicate across devices when there is more than one to use;
    # the wrap happens before .cuda(), matching the original ordering.
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    net = net.cuda()
    loss_fn = loss_fn.cuda()
    # Let cuDNN autotune convolution algorithms for the observed input sizes.
    cudnn.benchmark = True

    return net, loss_fn
| 186 | |
| 187 | |
| 188 | def train(train_loader, model, criterion, optimizer, epoch, opt): |