()
| 40 | |
| 41 | |
| 42 | def main(): |
| 43 | # launch from torch |
| 44 | parser = colossalai.legacy.get_default_parser() |
| 45 | args = parser.parse_args() |
| 46 | colossalai.legacy.launch_from_torch(config=args.config) |
| 47 | |
| 48 | # get logger |
| 49 | logger = get_dist_logger() |
| 50 | logger.info("initialized distributed environment", ranks=[0]) |
| 51 | |
| 52 | if hasattr(gpc.config, "LOG_PATH"): |
| 53 | if gpc.get_global_rank() == 0: |
| 54 | log_path = gpc.config.LOG_PATH |
| 55 | if not os.path.exists(log_path): |
| 56 | os.mkdir(log_path) |
| 57 | logger.log_to_file(log_path) |
| 58 | |
| 59 | use_pipeline = is_using_pp() |
| 60 | |
| 61 | # create model |
| 62 | model_kwargs = dict( |
| 63 | img_size=gpc.config.IMG_SIZE, |
| 64 | patch_size=gpc.config.PATCH_SIZE, |
| 65 | hidden_size=gpc.config.HIDDEN_SIZE, |
| 66 | depth=gpc.config.DEPTH, |
| 67 | num_heads=gpc.config.NUM_HEADS, |
| 68 | mlp_ratio=gpc.config.MLP_RATIO, |
| 69 | num_classes=10, |
| 70 | init_method="jax", |
| 71 | checkpoint=gpc.config.CHECKPOINT, |
| 72 | ) |
| 73 | |
| 74 | if use_pipeline: |
| 75 | pipelinable = PipelinableContext() |
| 76 | with pipelinable: |
| 77 | model = _create_vit_model(**model_kwargs) |
| 78 | pipelinable.to_layer_list() |
| 79 | pipelinable.policy = "uniform" |
| 80 | model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) |
| 81 | else: |
| 82 | model = _create_vit_model(**model_kwargs) |
| 83 | |
| 84 | # count number of parameters |
| 85 | total_numel = 0 |
| 86 | for p in model.parameters(): |
| 87 | total_numel += p.numel() |
| 88 | if not gpc.is_initialized(ParallelMode.PIPELINE): |
| 89 | pipeline_stage = 0 |
| 90 | else: |
| 91 | pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) |
| 92 | logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") |
| 93 | |
| 94 | # use synthetic dataset |
| 95 | # we train for 10 steps and eval for 5 steps per epoch |
| 96 | train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE) |
| 97 | test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE) |
| 98 | |
| 99 | # create loss function |
no test coverage detected
searching dependent graphs…