MCPcopy
hub / github.com/microsoft/Cream / TinyViT

Class TinyViT

TinyViT/models/tiny_vit.py:453–591  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

451
452
453class TinyViT(nn.Module):
454 def __init__(self, img_size=224, in_chans=3, num_classes=1000,
455 embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
456 num_heads=[3, 6, 12, 24],
457 window_sizes=[7, 7, 14, 7],
458 mlp_ratio=4.,
459 drop_rate=0.,
460 drop_path_rate=0.1,
461 use_checkpoint=False,
462 mbconv_expand_ratio=4.0,
463 local_conv_size=3,
464 layer_lr_decay=1.0,
465 ):
466 super().__init__()
467
468 self.num_classes = num_classes
469 self.depths = depths
470 self.num_layers = len(depths)
471 self.mlp_ratio = mlp_ratio
472
473 activation = nn.GELU
474
475 self.patch_embed = PatchEmbed(in_chans=in_chans,
476 embed_dim=embed_dims[0],
477 resolution=img_size,
478 activation=activation)
479
480 patches_resolution = self.patch_embed.patches_resolution
481 self.patches_resolution = patches_resolution
482
483 # stochastic depth
484 dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
485 sum(depths))] # stochastic depth decay rule
486
487 # build layers
488 self.layers = nn.ModuleList()
489 for i_layer in range(self.num_layers):
490 kwargs = dict(dim=embed_dims[i_layer],
491 input_resolution=(patches_resolution[0] // (2 ** i_layer),
492 patches_resolution[1] // (2 ** i_layer)),
493 depth=depths[i_layer],
494 drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
495 downsample=PatchMerging if (
496 i_layer < self.num_layers - 1) else None,
497 use_checkpoint=use_checkpoint,
498 out_dim=embed_dims[min(
499 i_layer + 1, len(embed_dims) - 1)],
500 activation=activation,
501 )
502 if i_layer == 0:
503 layer = ConvLayer(
504 conv_expand_ratio=mbconv_expand_ratio,
505 **kwargs,
506 )
507 else:
508 layer = BasicLayer(
509 num_heads=num_heads[i_layer],
510 window_size=window_sizes[i_layer],

Callers 1

build_model (Function) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected