MCPcopy
hub / github.com/lllyasviel/Paints-UNDO / __init__

Method __init__

diffusers_vdm/vae.py:155–216  ·  view source on GitHub ↗
(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **kwargs)

Source from the content-addressed store, hash-verified

153
154class Encoder(nn.Module):
155 def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
156 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
157 resolution, z_channels, double_z=True, **kwargs):
158 super().__init__()
159 self.ch = ch
160 self.temb_ch = 0
161 self.num_resolutions = len(ch_mult)
162 self.num_res_blocks = num_res_blocks
163 self.resolution = resolution
164 self.in_channels = in_channels
165
166 # downsampling
167 self.conv_in = torch.nn.Conv2d(in_channels,
168 self.ch,
169 kernel_size=3,
170 stride=1,
171 padding=1)
172
173 curr_res = resolution
174 in_ch_mult = (1,) + tuple(ch_mult)
175 self.in_ch_mult = in_ch_mult
176 self.down = nn.ModuleList()
177 for i_level in range(self.num_resolutions):
178 block = nn.ModuleList()
179 attn = nn.ModuleList()
180 block_in = ch * in_ch_mult[i_level]
181 block_out = ch * ch_mult[i_level]
182 for i_block in range(self.num_res_blocks):
183 block.append(ResnetBlock(in_channels=block_in,
184 out_channels=block_out,
185 temb_channels=self.temb_ch,
186 dropout=dropout))
187 block_in = block_out
188 if curr_res in attn_resolutions:
189 attn.append(Attention(block_in))
190 down = nn.Module()
191 down.block = block
192 down.attn = attn
193 if i_level != self.num_resolutions - 1:
194 down.downsample = EncoderDownSampleBlock(block_in, resamp_with_conv)
195 curr_res = curr_res // 2
196 self.down.append(down)
197
198 # middle
199 self.mid = nn.Module()
200 self.mid.block_1 = ResnetBlock(in_channels=block_in,
201 out_channels=block_in,
202 temb_channels=self.temb_ch,
203 dropout=dropout)
204 self.mid.attn_1 = Attention(block_in)
205 self.mid.block_2 = ResnetBlock(in_channels=block_in,
206 out_channels=block_in,
207 temb_channels=self.temb_ch,
208 dropout=dropout)
209
210 # end
211 self.norm_out = GroupNorm(block_in)
212 self.conv_out = torch.nn.Conv2d(block_in,

Callers 10

__init__ (Method) · 0.45
__init__ (Method) · 0.45
__init__ (Method) · 0.45
__init__ (Method) · 0.45
__init__ (Method) · 0.45
__init__ (Method) · 0.45
__init__ (Method) · 0.45
__init__ (Method) · 0.45
__init__ (Method) · 0.45
__init__ (Method) · 0.45

Calls 4

ResnetBlock (Class) · 0.85
Attention (Class) · 0.85
GroupNorm (Function) · 0.85

Tested by

no test coverage detected