| 207 | |
| 208 | |
class MLP(GradientModule):
    """Stack of ``D + 1`` linear layers with optional skip connections to the input.

    Layer ``i`` consumes the output of layer ``i - 1`` (the raw input for layer 0).
    For every index listed in ``skips`` the raw input is concatenated onto that
    layer's input along the last dimension before the linear map is applied.
    Hidden layers are followed by ``actvn``; the last layer by ``out_actvn``.

    Args:
        input_ch: size of the last dimension of the input tensor.
        W: output width of every layer except the last.
        D: index of the last layer; the network holds ``D + 1`` linear layers.
        out_ch: output width of the last layer.
        skips: layer indices whose input is concatenated with the raw input.
        actvn: activation applied after every layer except the last; a ``str``
            is resolved through ``get_function``.
        out_actvn: activation applied after the last layer; a ``str`` is
            resolved through ``get_function``.
        init_weight: callable invoked with ``weight.data`` of every layer except the last.
        init_bias: callable invoked with ``bias.data`` of every layer except the last.
        init_out_weight: callable invoked with ``weight.data`` of the last layer.
        init_out_bias: callable invoked with ``bias.data`` of the last layer.
        dtype: parameter dtype, either a ``torch.dtype`` or the name of an
            attribute of ``torch`` (e.g. ``"float"``).
    """

    def __init__(self, input_ch=32, W=256, D=8, out_ch=257, skips=[4], actvn=nn.ReLU(), out_actvn=nn.Identity(),
                 init_weight=nn.Identity(), init_bias=nn.Identity(), init_out_weight=nn.Identity(), init_out_bias=nn.Identity(), dtype=torch.float):
        super(MLP, self).__init__()
        dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype

        # Store a copy so the instance never aliases the shared default list.
        self.skips = list(skips)

        layers = []
        for i in range(D + 1):
            # Width of what the previous layer produced (the raw input for layer 0).
            in_features = input_ch if i == 0 else W
            # A skip layer sees its regular input concatenated with the raw input,
            # so its width must grow by input_ch. Deriving it from the previous
            # width keeps construction consistent with the forward pass even when
            # layer 0 itself is a skip layer.
            if i in self.skips:
                in_features += input_ch
            out_features = out_ch if i == D else W
            layers.append(nn.Linear(in_features, out_features, dtype=dtype))
        self.linears = nn.ModuleList(layers)

        self.actvn = get_function(actvn) if isinstance(actvn, str) else actvn
        self.out_actvn = get_function(out_actvn) if isinstance(out_actvn, str) else out_actvn

        # Weights are initialised in one pass and biases in a second pass; the
        # order is kept so initialisers that draw random numbers consume them
        # in the same sequence as before.
        last = len(self.linears) - 1
        for i, layer in enumerate(self.linears):
            if i == last:
                init_out_weight(layer.weight.data)
            else:
                init_weight(layer.weight.data)

        for i, layer in enumerate(self.linears):
            if i == last:
                init_out_bias(layer.bias.data)
            else:
                init_bias(layer.bias.data)

    def forward_with_previous(self, input: torch.Tensor):
        """Run the network and also return what was fed into the last layer.

        Args:
            input: tensor whose last dimension matches ``input_ch``.

        Returns:
            A pair ``(x, p)``: ``x`` is the output of the last layer after
            ``out_actvn``; ``p`` is the tensor that entered the last layer,
            captured before any skip concatenation for that layer.
        """
        x = input
        p = input  # bound up front so it is defined even if there are no layers
        last = len(self.linears) - 1
        for i, layer in enumerate(self.linears):
            p = x  # output of the previous layer, before skip concatenation
            if i in self.skips:
                x = torch.cat([x, input], dim=-1)
            activation = self.out_actvn if i == last else self.actvn
            x = activation(layer(x))
        return x, p

    def forward(self, input: torch.Tensor):
        """Return only the network output; see ``forward_with_previous``."""
        return self.forward_with_previous(input)[0]
| 252 | |
| 253 | |
| 254 | def setup_deterministic(fix_random=True, # all deterministic, same seed, no benchmarking |