MCPcopy
hub / github.com/lllyasviel/Paints-UNDO / ResnetBlock

Class ResnetBlock

diffusers_vdm/vae.py:92–151  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

90
91
92class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
             dropout, temb_channels=512):
    """Register the layers of the residual block.

    All arguments are keyword-only.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by the block; falls back to
            ``in_channels`` when ``None``.
        conv_shortcut: Only relevant when the channel counts differ. If
            truthy, the input is projected with a 3x3 ``conv_shortcut``;
            otherwise with a 1x1 ``nin_shortcut``.
        dropout: Drop probability handed to ``torch.nn.Dropout``.
        temb_channels: Width of the optional conditioning embedding
            (presumably a timestep embedding -- confirm with callers).
            The ``temb_proj`` linear layer is created only when this is > 0.

    NOTE(review): ``forward`` calls ``self.temb_proj`` whenever ``temb`` is
    not None, so a block built with ``temb_channels <= 0`` must be called
    with ``temb=None`` or it raises ``AttributeError``.
    """
    super().__init__()
    if out_channels is None:
        out_channels = in_channels
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.use_conv_shortcut = conv_shortcut

    # Shared settings for the 3x3 convolutions: stride 1 with padding 1
    # keeps the spatial size unchanged.
    conv3x3 = dict(kernel_size=3, stride=1, padding=1)

    # Keep this construction order: it fixes the order in which parameters
    # are randomly initialised and the layout of the state_dict.
    self.norm1 = GroupNorm(in_channels)
    self.conv1 = torch.nn.Conv2d(in_channels, out_channels, **conv3x3)
    if temb_channels > 0:
        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
    self.norm2 = GroupNorm(out_channels)
    self.dropout = torch.nn.Dropout(dropout)
    self.conv2 = torch.nn.Conv2d(out_channels, out_channels, **conv3x3)

    # Matching channel counts: the input needs no projection layer.
    if in_channels == out_channels:
        return
    if conv_shortcut:
        self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                             **conv3x3)
    else:
        self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                            kernel_size=1, stride=1,
                                            padding=0)
130
131 def forward(self, x, temb):
132 h = x
133 h = self.norm1(h)
134 h = nonlinearity(h)
135 h = self.conv1(h)
136
137 if temb is not None:
138 h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
139
140 h = self.norm2(h)
141 h = nonlinearity(h)
142 h = self.dropout(h)
143 h = self.conv2(h)
144
145 if self.in_channels != self.out_channels:
146 if self.use_conv_shortcut:
147 x = self.conv_shortcut(x)
148 else:
149 x = self.nin_shortcut(x)

Callers 1

__init__ Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected