(
self,
embed_dim: int = 96, # initial embed dim
num_heads: int = 1, # initial number of heads
drop_path_rate: float = 0.0, # stochastic depth
q_pool: int = 3, # number of q_pool stages
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
dim_mul: float = 2.0, # dim_mul factor at stage shift
head_mul: float = 2.0, # head_mul factor at stage shift
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
# window size per stage, when not using global att.
window_spec: Tuple[int, ...] = (
8,
4,
14,
7,
),
# global attn in these blocks
global_att_blocks: Tuple[int, ...] = (
12,
16,
20,
),
weights_path=None,
return_interm_layers=True, # return feats from every stage
)
| 172 | """ |
| 173 | |
| 174 | def __init__( |
| 175 | self, |
| 176 | embed_dim: int = 96, # initial embed dim |
| 177 | num_heads: int = 1, # initial number of heads |
| 178 | drop_path_rate: float = 0.0, # stochastic depth |
| 179 | q_pool: int = 3, # number of q_pool stages |
| 180 | q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages |
| 181 | stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage |
| 182 | dim_mul: float = 2.0, # dim_mul factor at stage shift |
| 183 | head_mul: float = 2.0, # head_mul factor at stage shift |
| 184 | window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), |
| 185 | # window size per stage, when not using global att. |
| 186 | window_spec: Tuple[int, ...] = ( |
| 187 | 8, |
| 188 | 4, |
| 189 | 14, |
| 190 | 7, |
| 191 | ), |
| 192 | # global attn in these blocks |
| 193 | global_att_blocks: Tuple[int, ...] = ( |
| 194 | 12, |
| 195 | 16, |
| 196 | 20, |
| 197 | ), |
| 198 | weights_path=None, |
| 199 | return_interm_layers=True, # return feats from every stage |
| 200 | ): |
| 201 | super().__init__() |
| 202 | |
| 203 | assert len(stages) == len(window_spec) |
| 204 | self.window_spec = window_spec |
| 205 | |
| 206 | depth = sum(stages) |
| 207 | self.q_stride = q_stride |
| 208 | self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] |
| 209 | assert 0 <= q_pool <= len(self.stage_ends[:-1]) |
| 210 | self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] |
| 211 | self.return_interm_layers = return_interm_layers |
| 212 | |
| 213 | self.patch_embed = PatchEmbed( |
| 214 | embed_dim=embed_dim, |
| 215 | ) |
| 216 | # Which blocks have global att? |
| 217 | self.global_att_blocks = global_att_blocks |
| 218 | |
| 219 | # Windowed positional embedding (https://arxiv.org/abs/2311.05613) |
| 220 | self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size |
| 221 | self.pos_embed = nn.Parameter( |
| 222 | torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) |
| 223 | ) |
| 224 | self.pos_embed_window = nn.Parameter( |
| 225 | torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) |
| 226 | ) |
| 227 | |
| 228 | dpr = [ |
| 229 | x.item() for x in torch.linspace(0, drop_path_rate, depth) |
| 230 | ] # stochastic depth decay rule |
| 231 |
no test coverage detected