MCPcopy
hub / github.com/ermongroup/ddim / ResnetBlock

Class ResnetBlock

models/diffusion.py:77–134  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

75
76
class ResnetBlock(nn.Module):
    """Pre-activation residual block with an additive timestep-embedding bias.

    The residual branch is::

        Normalize -> nonlinearity -> conv3x3
        + Linear(nonlinearity(temb)), broadcast over the spatial dims
        Normalize -> nonlinearity -> dropout -> conv3x3

    and the output is ``shortcut(x) + branch``.  ``Normalize`` and
    ``nonlinearity`` are helpers defined elsewhere in this module
    (presumably GroupNorm and swish -- TODO confirm).

    All constructor arguments are keyword-only.

    Args:
        in_channels: channel count of the input ``x``.
        out_channels: channel count of the output; defaults to
            ``in_channels`` when ``None``.
        conv_shortcut: when ``in_channels != out_channels``, project the skip
            path with a 3x3 conv (``conv_shortcut``) instead of a 1x1 conv
            (``nin_shortcut``).  When the channel counts are equal the skip
            path is the identity and this flag has no effect (no shortcut
            module is created, so checkpoints carry no extra weights).
        dropout: dropout probability applied before the second conv.
        temb_channels: feature size of the timestep embedding fed to
            ``forward``.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # NOTE: submodule creation order and attribute names are kept stable
        # on purpose -- they determine parameter-init RNG consumption and the
        # state_dict keys that saved checkpoints are loaded against.
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        # Maps the embedding to one additive bias per output channel.
        self.temb_proj = torch.nn.Linear(temb_channels,
                                         out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            # The skip path must change channel count so `x + h` is defined.
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        """Apply the residual block.

        Args:
            x: input feature map, expected ``(N, in_channels, H, W)``.
            temb: timestep embedding, expected ``(N, temb_channels)`` so that
                the projected ``(N, out_channels)`` bias broadcasts over
                ``H, W``; or ``None`` to skip the embedding bias entirely.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``; the 3x3/stride-1/
            padding-1 convs preserve the spatial size.
        """
        h = self.norm1(x)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Per-channel bias from the embedding, broadcast across H and W.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h

Callers 1

__init__ (method) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected