| 7 | |
| 8 | |
def build_model(config):
    """Construct the model selected by ``config.MODEL.TYPE``.

    Supported values:

    * ``'tiny_vit'`` -- builds a ``TinyViT`` from ``config.MODEL.TINY_VIT``
      together with ``config.DATA.IMG_SIZE``, ``config.MODEL.NUM_CLASSES``,
      ``config.MODEL.DROP_RATE``, ``config.MODEL.DROP_PATH_RATE``,
      ``config.TRAIN.USE_CHECKPOINT`` and ``config.TRAIN.LAYER_LR_DECAY``.
      ``TinyViT`` must already be available in this module's scope.
    * ``'clip_vit_large14_224'`` -- builds a ``CLIP`` (imported lazily from
      the sibling ``clip`` module) with fixed hyperparameters; only
      ``num_classes`` is read from ``config.MODEL.NUM_CLASSES``.

    Args:
        config: Configuration object exposing the attributes listed above.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``config.MODEL.TYPE`` is not one of the
            supported values.
    """
    model_type = config.MODEL.TYPE

    if model_type == 'tiny_vit':
        M = config.MODEL.TINY_VIT
        return TinyViT(
            img_size=config.DATA.IMG_SIZE,
            in_chans=M.IN_CHANS,
            num_classes=config.MODEL.NUM_CLASSES,
            embed_dims=M.EMBED_DIMS,
            depths=M.DEPTHS,
            num_heads=M.NUM_HEADS,
            window_sizes=M.WINDOW_SIZES,
            mlp_ratio=M.MLP_RATIO,
            drop_rate=config.MODEL.DROP_RATE,
            drop_path_rate=config.MODEL.DROP_PATH_RATE,
            use_checkpoint=config.TRAIN.USE_CHECKPOINT,
            mbconv_expand_ratio=M.MBCONV_EXPAND_RATIO,
            local_conv_size=M.LOCAL_CONV_SIZE,
            layer_lr_decay=config.TRAIN.LAYER_LR_DECAY,
        )

    if model_type == 'clip_vit_large14_224':
        # Imported here rather than at module level so the CLIP module is
        # only loaded when this model type is actually requested.
        from .clip import CLIP

        # Hyperparameters are hard-coded for this variant; the class count
        # is the only value taken from the config.
        kwargs = {
            'embed_dim': 768,
            'image_resolution': 224,
            'vision_layers': 24,
            'vision_width': 1024,
            'vision_patch_size': 14,
            'num_classes': config.MODEL.NUM_CLASSES,
        }
        return CLIP(**kwargs)

    raise NotImplementedError(f"Unknown model: {model_type}")