MCPcopy
hub / github.com/microsoft/Swin-Transformer / build_model

Function build_model

models/build.py:15–121  ·  view source on GitHub ↗
(config, is_pretrain=False)

Source from the content-addressed store, hash-verified

13
14
15def build_model(config, is_pretrain=False):
16 model_type = config.MODEL.TYPE
17
18 # accelerate layernorm
19 if config.FUSED_LAYERNORM:
20 try:
21 import apex as amp
22 layernorm = amp.normalization.FusedLayerNorm
23 except:
24 layernorm = None
25 print("To use FusedLayerNorm, please install apex.")
26 else:
27 import torch.nn as nn
28 layernorm = nn.LayerNorm
29
30 if is_pretrain:
31 model = build_simmim(config)
32 return model
33
34 if model_type == 'swin':
35 model = SwinTransformer(img_size=config.DATA.IMG_SIZE,
36 patch_size=config.MODEL.SWIN.PATCH_SIZE,
37 in_chans=config.MODEL.SWIN.IN_CHANS,
38 num_classes=config.MODEL.NUM_CLASSES,
39 embed_dim=config.MODEL.SWIN.EMBED_DIM,
40 depths=config.MODEL.SWIN.DEPTHS,
41 num_heads=config.MODEL.SWIN.NUM_HEADS,
42 window_size=config.MODEL.SWIN.WINDOW_SIZE,
43 mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
44 qkv_bias=config.MODEL.SWIN.QKV_BIAS,
45 qk_scale=config.MODEL.SWIN.QK_SCALE,
46 drop_rate=config.MODEL.DROP_RATE,
47 drop_path_rate=config.MODEL.DROP_PATH_RATE,
48 ape=config.MODEL.SWIN.APE,
49 norm_layer=layernorm,
50 patch_norm=config.MODEL.SWIN.PATCH_NORM,
51 use_checkpoint=config.TRAIN.USE_CHECKPOINT,
52 fused_window_process=config.FUSED_WINDOW_PROCESS)
53 elif model_type == 'swinv2':
54 model = SwinTransformerV2(img_size=config.DATA.IMG_SIZE,
55 patch_size=config.MODEL.SWINV2.PATCH_SIZE,
56 in_chans=config.MODEL.SWINV2.IN_CHANS,
57 num_classes=config.MODEL.NUM_CLASSES,
58 embed_dim=config.MODEL.SWINV2.EMBED_DIM,
59 depths=config.MODEL.SWINV2.DEPTHS,
60 num_heads=config.MODEL.SWINV2.NUM_HEADS,
61 window_size=config.MODEL.SWINV2.WINDOW_SIZE,
62 mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
63 qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
64 drop_rate=config.MODEL.DROP_RATE,
65 drop_path_rate=config.MODEL.DROP_PATH_RATE,
66 ape=config.MODEL.SWINV2.APE,
67 patch_norm=config.MODEL.SWINV2.PATCH_NORM,
68 use_checkpoint=config.TRAIN.USE_CHECKPOINT,
69 pretrained_window_sizes=config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES)
70 elif model_type == 'swin_moe':
71 model = SwinTransformerMoE(img_size=config.DATA.IMG_SIZE,
72 patch_size=config.MODEL.SWIN_MOE.PATCH_SIZE,

Callers 4

main (Function) · 0.90
main (Function) · 0.90
main (Function) · 0.90
main (Function) · 0.90

Calls 5

build_simmim (Function) · 0.85
SwinTransformer (Class) · 0.85
SwinTransformerV2 (Class) · 0.85
SwinTransformerMoE (Class) · 0.85
SwinMLP (Class) · 0.85

Tested by

no test coverage detected