Main. Args: **kwargs: Additional keyword arguments.
(**kwargs)
| 59 | |
| 60 | |
| 61 | def main(**kwargs): |
| 62 | |
| 63 | # set random seed |
| 64 | """Main. |
| 65 | |
| 66 | Args: |
| 67 | **kwargs: Additional keyword arguments. |
| 68 | """ |
| 69 | set_all_random_seed(kwargs.get("seed", 0)) |
| 70 | torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled) |
| 71 | torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark) |
| 72 | torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True) |
| 73 | # enable TF32 for CUDA matmul (on by default; override with the enable_tf32 kwarg) |
| 74 | torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True) |
| 75 | |
| 76 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) |
| 77 | if local_rank == 0: |
| 78 | tables.print() |
| 79 | # Check if we are using DDP or FSDP |
| 80 | use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1 |
| 81 | use_fsdp = kwargs.get("use_fsdp", False) |
| 82 | # use_ddp = False if use_fsdp else use_fsdp |
| 83 | if use_ddp or use_fsdp: |
| 84 | dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://") |
| 85 | torch.cuda.set_device(local_rank) |
| 86 | |
| 87 | logging.info("Build model, frontend, tokenizer") |
| 88 | device = kwargs.get("device", "cuda") |
| 89 | kwargs["device"] = "cpu" |
| 90 | model = AutoModel(**kwargs) |
| 91 | |
| 92 | # save config.yaml |
| 93 | if ( |
| 94 | (use_ddp or use_fsdp) |
| 95 | and dist.get_rank() == 0 |
| 96 | or not (use_ddp or use_fsdp) |
| 97 | and local_rank == 0 |
| 98 | ): |
| 99 | prepare_model_dir(**kwargs) |
| 100 | |
| 101 | # parse kwargs |
| 102 | kwargs = model.kwargs |
| 103 | kwargs["device"] = device |
| 104 | tokenizer = kwargs["tokenizer"] |
| 105 | frontend = kwargs["frontend"] |
| 106 | model = model.model |
| 107 | del kwargs["model"] |
| 108 | |
| 109 | # freeze_param |
| 110 | freeze_param = kwargs.get("freeze_param", None) |
| 111 | if freeze_param is not None: |
| 112 | if "," in freeze_param: |
| 113 | freeze_param = freeze_param.split(",") |
| 114 | if not isinstance(freeze_param, (list, tuple)): |
| 115 | freeze_param = (freeze_param,) |
| 116 | logging.info("freeze_param is not None: %s", freeze_param) |
| 117 | for t in freeze_param: |
| 118 | for k, p in model.named_parameters(): |
no test coverage detected
searching dependent graphs…