MCPcopy
hub / github.com/yangchris11/samurai / MultiScaleBlock

Class MultiScaleBlock

sam2/sam2/modeling/backbones/hieradet.py:84–166  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

82
83
84class MultiScaleBlock(nn.Module):
85 def __init__(
86 self,
87 dim: int,
88 dim_out: int,
89 num_heads: int,
90 mlp_ratio: float = 4.0,
91 drop_path: float = 0.0,
92 norm_layer: Union[nn.Module, str] = "LayerNorm",
93 q_stride: Tuple[int, int] = None,
94 act_layer: nn.Module = nn.GELU,
95 window_size: int = 0,
96 ):
97 super().__init__()
98
99 if isinstance(norm_layer, str):
100 norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
101
102 self.dim = dim
103 self.dim_out = dim_out
104 self.norm1 = norm_layer(dim)
105
106 self.window_size = window_size
107
108 self.pool, self.q_stride = None, q_stride
109 if self.q_stride:
110 self.pool = nn.MaxPool2d(
111 kernel_size=q_stride, stride=q_stride, ceil_mode=False
112 )
113
114 self.attn = MultiScaleAttention(
115 dim,
116 dim_out,
117 num_heads=num_heads,
118 q_pool=self.pool,
119 )
120 self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
121
122 self.norm2 = norm_layer(dim_out)
123 self.mlp = MLP(
124 dim_out,
125 int(dim_out * mlp_ratio),
126 dim_out,
127 num_layers=2,
128 activation=act_layer,
129 )
130
131 if dim != dim_out:
132 self.proj = nn.Linear(dim, dim_out)
133
134 def forward(self, x: torch.Tensor) -> torch.Tensor:
135 shortcut = x # B, H, W, C
136 x = self.norm1(x)
137
138 # Skip connection
139 if self.dim != self.dim_out:
140 shortcut = do_pool(self.proj(x), self.pool)
141

Callers 1

__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected