(args, model)
| 20 | |
| 21 | |
| 22 | def build_optimizer(args, model): |
| 23 | def exclude( |
| 24 | n, p): return p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n |
| 25 | |
| 26 | def include(n, p): return not exclude(n, p) |
| 27 | |
| 28 | named_parameters = list(model.named_parameters()) |
| 29 | # we create three optimizers: one each for the image encoder, the text encoder, and the joint part |
| 30 | model_parts = [ |
| 31 | list(model.image_named_params()), |
| 32 | list(model.text_named_params()), |
| 33 | list(model.joint_named_params()), |
| 34 | ] |
| 35 | |
| 36 | cnt1 = sum(v.numel() for k, v in named_parameters if v.requires_grad) |
| 37 | cnt2 = sum(sum(v.numel() for k, v in part if v.requires_grad) |
| 38 | for part in model_parts) |
| 39 | assert cnt1 == cnt2, f"cnt1 {cnt1} != cnt2 {cnt2}" |
| 40 | |
| 41 | optimizer = [] |
| 42 | part_names = ['image', 'text', 'joint'] |
| 43 | assert len(model_parts) == len(part_names) |
| 44 | for name, named_parameters in zip(part_names, model_parts): |
| 45 | gain_or_bias_params = [p for n, p in named_parameters if exclude( |
| 46 | n, p) and p.requires_grad and "l0_module" not in n] |
| 47 | rest_params = [p for n, p in named_parameters if include( |
| 48 | n, p) and p.requires_grad and "l0_module" not in n] |
| 49 | params_groups = [ |
| 50 | {"params": gain_or_bias_params, "weight_decay": 0.}, |
| 51 | {"params": rest_params, "weight_decay": args.wd}, |
| 52 | ] |
| 53 | |
| 54 | num_opt_params = 0 |
| 55 | for pg in params_groups: |
| 56 | num_opt_params += sum(p.numel() for p in pg['params']) |
| 57 | |
| 58 | logging.info(f'number of optimizer ({name}) params: {num_opt_params}') |
| 59 | |
| 60 | if num_opt_params > 0: |
| 61 | optimizer_i = optim.AdamW( |
| 62 | params_groups, |
| 63 | lr=args.lr, |
| 64 | betas=(args.beta1, args.beta2), |
| 65 | eps=args.eps, |
| 66 | ) |
| 67 | else: |
| 68 | optimizer_i = EmptyOptimizer() |
| 69 | optimizer.append(optimizer_i) |
| 70 | |
| 71 | if args.prune_image or args.prune_text: |
| 72 | lr_l0 = 0.02 |
| 73 | lr_lamda = args.l0lr |
| 74 | l0_params = [] |
| 75 | # add l0 optimizer |
| 76 | if args.prune_image: |
| 77 | l0_params.extend([ |
| 78 | { |
| 79 | "params": [p for n, p in model.image_named_params() if p.requires_grad and "lambda" not in n and "l0_module" in n], |
no test coverage detected