r"""run on 1 GPU
(gpus_per_worker, num_steps=1, batch_size=1, num_classes=8)
| 349 | |
| 350 | |
def baseline_main(gpus_per_worker, num_steps=1, batch_size=1, num_classes=8):
    r"""Run the reference training loop on 1 GPU.

    For every step, one random input batch and one random label batch are
    generated per entry in ``range(gpus_per_worker)`` (each with its own
    seed), concatenated, and pushed through a linear feature extractor, an
    ``ArcFaceLinear`` head and a cross-entropy loss.  Only the parameters of
    the ``ArcFaceLinear`` head are handed to the SGD optimizer.

    Args:
        gpus_per_worker: number of per-GPU batches to generate and
            concatenate at each step.
        num_steps: number of optimisation steps to run.
        batch_size: number of samples in each per-GPU batch.
        num_classes: number of classes; labels are drawn from
            ``[0, num_classes - 1]``.

    Returns:
        A list with one dict per step holding the keys ``'features/size'``,
        ``'features/norm'``, ``'loss'`` and ``'logits/grad/norm'``.
    """
    image_size, emb_size = 6, 3
    learning_rate, momentum = 0.1, 0.9

    zeros_init = sailfish.ZerosInitializer()

    # Fixed seeds are set right before the layers are constructed.
    torch.manual_seed(42)
    random.seed(42)
    fe = torch.nn.Linear(image_size, emb_size).cuda()
    fc = sailfish.ArcFaceLinear(
        emb_size, num_classes, weight_initializer=zeros_init).cuda()
    fc_params = list(fc.parameters())
    criterion = torch.nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        fc.parameters(), lr=learning_rate, momentum=momentum)
    fc.train()
    criterion.train()

    history = []
    for step in range(num_steps):
        per_gpu_features = []
        per_gpu_labels = []
        for gpu in range(gpus_per_worker):
            # Each (step, gpu) pair gets its own seed for both RNGs.
            shard_seed = 42 * step + gpu
            torch.manual_seed(shard_seed)
            random.seed(shard_seed)

            images = torch.randn([batch_size, image_size]).cuda()
            per_gpu_features.append(fe(images))

            labels = [
                random.randint(0, num_classes - 1)
                for _ in range(batch_size)
            ]
            per_gpu_labels.append(torch.as_tensor(labels).cuda())

        all_features = torch.cat(per_gpu_features)
        all_label = torch.cat(per_gpu_labels)

        # Re-seed both RNGs with the per-step seed before the forward pass.
        step_seed = 42 * step
        torch.manual_seed(step_seed)
        random.seed(step_seed)

        record = {
            'features/size': list(all_features.size()),
            'features/norm': torch.norm(all_features).item(),
        }

        logits = fc(all_features, all_label)
        loss = criterion(logits, all_label)
        record['loss'] = loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Gradient norm of the first parameter of the ArcFaceLinear head.
        record['logits/grad/norm'] = torch.norm(fc_params[0].grad).item()
        history.append(record)

    return history
| 403 | |
| 404 | |
| 405 | class TestArcFaceLinear(unittest.TestCase): |