| 451 | |
| 452 | |
| 453 | class TinyViT(nn.Module): |
| 454 | def __init__(self, img_size=224, in_chans=3, num_classes=1000, |
| 455 | embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2], |
| 456 | num_heads=[3, 6, 12, 24], |
| 457 | window_sizes=[7, 7, 14, 7], |
| 458 | mlp_ratio=4., |
| 459 | drop_rate=0., |
| 460 | drop_path_rate=0.1, |
| 461 | use_checkpoint=False, |
| 462 | mbconv_expand_ratio=4.0, |
| 463 | local_conv_size=3, |
| 464 | layer_lr_decay=1.0, |
| 465 | ): |
| 466 | super().__init__() |
| 467 | |
| 468 | self.num_classes = num_classes |
| 469 | self.depths = depths |
| 470 | self.num_layers = len(depths) |
| 471 | self.mlp_ratio = mlp_ratio |
| 472 | |
| 473 | activation = nn.GELU |
| 474 | |
| 475 | self.patch_embed = PatchEmbed(in_chans=in_chans, |
| 476 | embed_dim=embed_dims[0], |
| 477 | resolution=img_size, |
| 478 | activation=activation) |
| 479 | |
| 480 | patches_resolution = self.patch_embed.patches_resolution |
| 481 | self.patches_resolution = patches_resolution |
| 482 | |
| 483 | # stochastic depth |
| 484 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, |
| 485 | sum(depths))] # stochastic depth decay rule |
| 486 | |
| 487 | # build layers |
| 488 | self.layers = nn.ModuleList() |
| 489 | for i_layer in range(self.num_layers): |
| 490 | kwargs = dict(dim=embed_dims[i_layer], |
| 491 | input_resolution=(patches_resolution[0] // (2 ** i_layer), |
| 492 | patches_resolution[1] // (2 ** i_layer)), |
| 493 | depth=depths[i_layer], |
| 494 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], |
| 495 | downsample=PatchMerging if ( |
| 496 | i_layer < self.num_layers - 1) else None, |
| 497 | use_checkpoint=use_checkpoint, |
| 498 | out_dim=embed_dims[min( |
| 499 | i_layer + 1, len(embed_dims) - 1)], |
| 500 | activation=activation, |
| 501 | ) |
| 502 | if i_layer == 0: |
| 503 | layer = ConvLayer( |
| 504 | conv_expand_ratio=mbconv_expand_ratio, |
| 505 | **kwargs, |
| 506 | ) |
| 507 | else: |
| 508 | layer = BasicLayer( |
| 509 | num_heads=num_heads[i_layer], |
| 510 | window_size=window_sizes[i_layer], |