(self, w_in, w_out, norm, activation_class)
| 77 | """Simple stem for ImageNet: 3x3, BN, AF.""" |
| 78 | |
def __init__(self, w_in, w_out, norm, activation_class):
    """Build the stem layers: a 3x3 stride-2 conv, a norm layer, an activation.

    Args:
        w_in: input width passed to the base class and to ``conv2d``
            (presumably an input channel count -- confirm against conv2d).
        w_out: output width used by the conv and by the norm layer.
        norm: norm specification forwarded unchanged to ``get_norm``.
        activation_class: zero-argument callable that builds the activation.
    """
    stride = 2  # one value shared by the base-class call and the conv
    super().__init__(w_in, w_out, stride)
    # NOTE: keep this assignment order. forward() iterates self.children(),
    # which (assuming a torch.nn.Module base) yields submodules in the order
    # they were assigned: conv, then bn, then af.
    self.conv = conv2d(w_in, w_out, 3, stride=stride)
    self.bn = get_norm(norm, w_out)
    self.af = activation_class()
| 84 | |
| 85 | def forward(self, x): |
| 86 | for layer in self.children(): |