MCPcopy Index your code
hub / github.com/LTH14/mar / __init__

Method __init__

models/vae.py:162–243  ·  view source on GitHub ↗
(
        self,
        *,
        ch=128,
        out_ch=3,
        ch_mult=(1, 1, 2, 2, 4),
        num_res_blocks=2,
        attn_resolutions=(16,),
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=256,
        z_channels=16,
        double_z=True,
        **ignore_kwargs,
    )

Source from the content-addressed store, hash-verified

160
161class Encoder(nn.Module):
162 def __init__(
163 self,
164 *,
165 ch=128,
166 out_ch=3,
167 ch_mult=(1, 1, 2, 2, 4),
168 num_res_blocks=2,
169 attn_resolutions=(16,),
170 dropout=0.0,
171 resamp_with_conv=True,
172 in_channels=3,
173 resolution=256,
174 z_channels=16,
175 double_z=True,
176 **ignore_kwargs,
177 ):
178 super().__init__()
179 self.ch = ch
180 self.temb_ch = 0
181 self.num_resolutions = len(ch_mult)
182 self.num_res_blocks = num_res_blocks
183 self.resolution = resolution
184 self.in_channels = in_channels
185
186 # downsampling
187 self.conv_in = torch.nn.Conv2d(
188 in_channels, self.ch, kernel_size=3, stride=1, padding=1
189 )
190
191 curr_res = resolution
192 in_ch_mult = (1,) + tuple(ch_mult)
193 self.down = nn.ModuleList()
194 for i_level in range(self.num_resolutions):
195 block = nn.ModuleList()
196 attn = nn.ModuleList()
197 block_in = ch * in_ch_mult[i_level]
198 block_out = ch * ch_mult[i_level]
199 for i_block in range(self.num_res_blocks):
200 block.append(
201 ResnetBlock(
202 in_channels=block_in,
203 out_channels=block_out,
204 temb_channels=self.temb_ch,
205 dropout=dropout,
206 )
207 )
208 block_in = block_out
209 if curr_res in attn_resolutions:
210 attn.append(AttnBlock(block_in))
211 down = nn.Module()
212 down.block = block
213 down.attn = attn
214 if i_level != self.num_resolutions - 1:
215 down.downsample = Downsample(block_in, resamp_with_conv)
216 curr_res = curr_res // 2
217 self.down.append(down)
218
219 # middle

Callers 6

__init__ (Method) · 0.45
__init__ (Method) · 0.45
__init__ (Method) · 0.45
__init__ (Method) · 0.45
__init__ (Method) · 0.45
__init__ (Method) · 0.45

Calls 4

ResnetBlock (Class) · 0.85
AttnBlock (Class) · 0.85
Downsample (Class) · 0.85
Normalize (Function) · 0.85

Tested by

no test coverage detected