(config, is_pretrain=False)
| 13 | |
| 14 | |
| 15 | def build_model(config, is_pretrain=False): |
| 16 | model_type = config.MODEL.TYPE |
| 17 | |
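| 18 | # choose the LayerNorm class: apex FusedLayerNorm when config.FUSED_LAYERNORM is set (None if the apex import fails), otherwise torch.nn.LayerNorm |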
| 19 | if config.FUSED_LAYERNORM: |
| 20 | try: |
| 21 | import apex as amp |
| 22 | layernorm = amp.normalization.FusedLayerNorm |
| 23 | except: |
| 24 | layernorm = None |
| 25 | print("To use FusedLayerNorm, please install apex.") |
| 26 | else: |
| 27 | import torch.nn as nn |
| 28 | layernorm = nn.LayerNorm |
| 29 | |
| 30 | if is_pretrain: |
| 31 | model = build_simmim(config) |
| 32 | return model |
| 33 | |
| 34 | if model_type == 'swin': |
| 35 | model = SwinTransformer(img_size=config.DATA.IMG_SIZE, |
| 36 | patch_size=config.MODEL.SWIN.PATCH_SIZE, |
| 37 | in_chans=config.MODEL.SWIN.IN_CHANS, |
| 38 | num_classes=config.MODEL.NUM_CLASSES, |
| 39 | embed_dim=config.MODEL.SWIN.EMBED_DIM, |
| 40 | depths=config.MODEL.SWIN.DEPTHS, |
| 41 | num_heads=config.MODEL.SWIN.NUM_HEADS, |
| 42 | window_size=config.MODEL.SWIN.WINDOW_SIZE, |
| 43 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, |
| 44 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, |
| 45 | qk_scale=config.MODEL.SWIN.QK_SCALE, |
| 46 | drop_rate=config.MODEL.DROP_RATE, |
| 47 | drop_path_rate=config.MODEL.DROP_PATH_RATE, |
| 48 | ape=config.MODEL.SWIN.APE, |
| 49 | norm_layer=layernorm, |
| 50 | patch_norm=config.MODEL.SWIN.PATCH_NORM, |
| 51 | use_checkpoint=config.TRAIN.USE_CHECKPOINT, |
| 52 | fused_window_process=config.FUSED_WINDOW_PROCESS) |
| 53 | elif model_type == 'swinv2': |
| 54 | model = SwinTransformerV2(img_size=config.DATA.IMG_SIZE, |
| 55 | patch_size=config.MODEL.SWINV2.PATCH_SIZE, |
| 56 | in_chans=config.MODEL.SWINV2.IN_CHANS, |
| 57 | num_classes=config.MODEL.NUM_CLASSES, |
| 58 | embed_dim=config.MODEL.SWINV2.EMBED_DIM, |
| 59 | depths=config.MODEL.SWINV2.DEPTHS, |
| 60 | num_heads=config.MODEL.SWINV2.NUM_HEADS, |
| 61 | window_size=config.MODEL.SWINV2.WINDOW_SIZE, |
| 62 | mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, |
| 63 | qkv_bias=config.MODEL.SWINV2.QKV_BIAS, |
| 64 | drop_rate=config.MODEL.DROP_RATE, |
| 65 | drop_path_rate=config.MODEL.DROP_PATH_RATE, |
| 66 | ape=config.MODEL.SWINV2.APE, |
| 67 | patch_norm=config.MODEL.SWINV2.PATCH_NORM, |
| 68 | use_checkpoint=config.TRAIN.USE_CHECKPOINT, |
| 69 | pretrained_window_sizes=config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES) |
| 70 | elif model_type == 'swin_moe': |
| 71 | model = SwinTransformerMoE(img_size=config.DATA.IMG_SIZE, |
| 72 | patch_size=config.MODEL.SWIN_MOE.PATCH_SIZE, |
no test coverage detected