MCPcopy
hub / github.com/zju3dv/4K4D / MLP

Class MLP

easyvolcap/utils/net_utils.py:209–251  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

207
208
class MLP(GradientModule):
    """Fully connected network whose selected layers also receive the raw input.

    The network is a stack of ``D + 1`` ``nn.Linear`` layers: ``D`` layers that
    output ``W`` features followed by one layer that outputs ``out_ch``
    features. Every layer whose index appears in ``skips`` is fed the
    concatenation (along the last dimension) of the previous activation and
    the original network input.
    """

    def __init__(self, input_ch=32, W=256, D=8, out_ch=257, skips=[4], actvn=nn.ReLU(), out_actvn=nn.Identity(),
                 init_weight=nn.Identity(), init_bias=nn.Identity(), init_out_weight=nn.Identity(), init_out_bias=nn.Identity(), dtype=torch.float):
        """Build the linear layers and apply the requested initialisers.

        Args:
            input_ch: Size of the last dimension of the network input.
            W: Number of features produced by every layer except the last.
            D: Index of the output layer; ``D + 1`` linear layers are created.
            out_ch: Number of features produced by the last layer.
            skips: Indices of the layers that additionally receive the raw
                input. The sequence is copied, never stored by reference.
            actvn: Activation applied after every layer except the last. A
                string is resolved through ``get_function`` (presumably a
                name-to-callable lookup defined elsewhere -- confirm there).
            out_actvn: Activation applied after the last layer; strings are
                resolved the same way as ``actvn``.
            init_weight: Callable invoked on ``weight.data`` of every layer
                except the last. Its return value is ignored, so it has to
                work in place; the ``nn.Identity()`` default leaves the
                ``nn.Linear`` initialisation untouched.
            init_bias: Same as ``init_weight`` but for ``bias.data``.
            init_out_weight: Same as ``init_weight`` but for the last layer.
            init_out_bias: Same as ``init_bias`` but for the last layer.
            dtype: Parameter dtype, either a ``torch.dtype`` or the name of
                one (looked up as an attribute of ``torch``).
        """
        super().__init__()
        dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype

        # Copy instead of aliasing: the default ``[4]`` is a single list object
        # shared by every call, so storing it directly would let a mutation of
        # one model's ``skips`` leak into all other models (and into the
        # caller's own list when one is passed in).
        self.skips = list(skips)

        layers = []
        for i in range(D + 1):
            # Width of the tensor that reaches this layer before any skip
            # concatenation: the raw input for layer 0, otherwise the output
            # of the previous (width ``W``) layer.
            in_features = input_ch if i == 0 else W
            if i in self.skips:
                # ``forward_with_previous`` concatenates the raw input here.
                # Using the real incoming width keeps layer 0 consistent too
                # (it receives ``input`` twice, i.e. ``2 * input_ch``).
                in_features += input_ch
            out_features = out_ch if i == D else W
            layers.append(nn.Linear(in_features, out_features, dtype=dtype))
        self.linears = nn.ModuleList(layers)

        self.actvn = get_function(actvn) if isinstance(actvn, str) else actvn
        self.out_actvn = get_function(out_actvn) if isinstance(out_actvn, str) else out_actvn

        # All weights are initialised before any bias, so that initialisers
        # drawing random numbers consume the generator in a fixed order.
        last = len(self.linears) - 1
        for i, l in enumerate(self.linears):
            if i == last:
                init_out_weight(l.weight.data)
            else:
                init_weight(l.weight.data)

        for i, l in enumerate(self.linears):
            if i == last:
                init_out_bias(l.bias.data)
            else:
                init_bias(l.bias.data)

    def forward_with_previous(self, input: torch.Tensor):
        """Run the network and also return what entered the last layer.

        Args:
            input: Tensor whose last dimension has ``input_ch`` entries.

        Returns:
            A tuple ``(output, previous)``. ``output`` is the result of the
            last layer after ``out_actvn``. ``previous`` is the tensor that
            reached the last layer before any skip concatenation, i.e. the
            activated output of the layer before it, or ``input`` itself when
            the network has a single layer.
        """
        last = len(self.linears) - 1
        x = input
        for i, l in enumerate(self.linears):
            p = x  # store output of previous layer
            if i in self.skips:
                x = torch.cat([x, input], dim=-1)
            a = self.out_actvn if i == last else self.actvn
            x = a(l(x))  # actual forward
        return x, p

    def forward(self, input: torch.Tensor):
        """Return only the network output for ``input``."""
        return self.forward_with_previous(input)[0]
252
253
254def setup_deterministic(fix_random=True, # all deterministic, same seed, no benchmarking

Callers 7

__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected