(args, model, has_apex, filter_bias_and_bn=True)
| 94 | |
| 95 | |
def create_optimizer_supernet(args, model, has_apex, filter_bias_and_bn=True):
    """Build the torch optimizer selected by ``args.opt`` for a supernet model.

    Args:
        args: Namespace-like object. Reads ``opt``, ``weight_decay``,
            ``momentum`` (SGD variants), ``opt_eps`` (Adam) and ``lr``
            (only when the optimizer name contains ``adamw`` or ``radam``).
        model: Object exposing ``parameters()``; it is also handed to
            ``add_weight_decay_supernet`` when weight-decay filtering applies.
        has_apex: Whether APEX is available. Only checked when the optimizer
            name contains ``fused``.
        filter_bias_and_bn: When true and the weight decay is non-zero, the
            parameters come from ``add_weight_decay_supernet`` and the
            optimizer-level weight decay is set to 0.

    Returns:
        An ``optim.SGD`` or ``optim.Adam`` instance.

    Raises:
        AssertionError: If a ``fused`` optimizer is requested without APEX
            and CUDA.
        ValueError: If the last ``_``-separated token of ``args.opt`` is not
            one of ``sgd``, ``nesterov``, ``momentum`` or ``adam``.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        parameters = add_weight_decay_supernet(model, args, weight_decay)
        # Decay was handed to the helper above, so the optimizer-level
        # default is zeroed here.
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    # Only the last '_'-separated token selects the optimizer; any prefix
    # (e.g. 'fused_') is dropped at this point.
    opt_lower = opt_lower.split('_')[-1]

    # NOTE(review): no `lr` is passed to the constructors below. Presumably
    # the groups from add_weight_decay_supernet carry their own learning
    # rate -- confirm, and check what the plain `model.parameters()` path
    # should use.
    if opt_lower in ('sgd', 'nesterov', 'momentum'):
        # 'sgd' and 'nesterov' both enable Nesterov momentum; only
        # 'momentum' selects classical momentum.
        optimizer = optim.SGD(
            parameters,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov=(opt_lower != 'momentum'))
    elif opt_lower == 'adam':
        optimizer = optim.Adam(
            parameters, weight_decay=weight_decay, eps=args.opt_eps)
    else:
        # Raise explicitly instead of `assert False and "..."`: that
        # expression is always False (the message was never shown) and the
        # assert disappears under `python -O`.
        raise ValueError('Invalid optimizer: {!r}'.format(args.opt))

    return optimizer
| 133 | |
| 134 | |
| 135 | def convert_lowercase(cfg): |
no test coverage detected