MCPcopy Index your code
hub / github.com/LTH14/mar / ResnetBlock

Class ResnetBlock

models/vae.py:55–112  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

53
54
class ResnetBlock(nn.Module):
    """Residual block: two (normalize -> nonlinearity -> 3x3 conv) stages plus a skip path.

    An optional embedding ``temb`` is projected with a linear layer and added
    to the feature map between the two stages. When the input and output
    channel counts differ, the skip path is passed through a convolution so
    the two branches can be summed.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count of the output; falls back to
            ``in_channels`` when ``None``.
        conv_shortcut: Only consulted when the channel counts differ. If
            truthy, the skip path uses a 3x3 convolution; otherwise a 1x1
            convolution.
        dropout: Value handed to ``torch.nn.Dropout``.
        temb_channels: Input width of the embedding projection. No projection
            layer is created when this is not greater than zero.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # Shared geometry for every 3x3 convolution in this block.
        three_by_three = {"kernel_size": 3, "stride": 1, "padding": 1}

        # NOTE: sub-modules are created in the same order as before so that
        # parameter registration order and random initialisation are unchanged.
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, **three_by_three)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, **three_by_three)

        # A projection on the skip path is only needed when channels change.
        channels_change = in_channels != out_channels
        if channels_change and conv_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, **three_by_three
            )
        elif channels_change:
            self.nin_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x, temb):
        """Run the block on ``x`` and return ``skip(x) + residual(x)``.

        Args:
            x: Input feature map fed to both the residual and skip branches.
            temb: Embedding to inject after the first convolution, or ``None``
                to skip the injection. A non-``None`` value requires that the
                block was built with ``temb_channels > 0``, since the
                projection layer does not exist otherwise.
        """
        # First stage.
        residual = self.conv1(nonlinearity(self.norm1(x)))

        if temb is not None:
            # The projected embedding gets two trailing singleton axes so it
            # broadcasts across the feature map's last two dimensions.
            projected = self.temb_proj(nonlinearity(temb))
            residual = residual + projected[:, :, None, None]

        # Second stage, with dropout applied just before the convolution.
        residual = nonlinearity(self.norm2(residual))
        residual = self.conv2(self.dropout(residual))

        if self.in_channels == self.out_channels:
            return x + residual

        # Channel counts differ: project the skip path before summing.
        if self.use_conv_shortcut:
            shortcut = self.conv_shortcut
        else:
            shortcut = self.nin_shortcut
        return shortcut(x) + residual

Callers 2

__init__Method · 0.85
__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected