| 1102 | |
| 1103 | class TorchPool2DLayer(nn.Module): |
| 1104 | def __init__(self, in_channels, hparams, **kwargs): |
| 1105 | super(TorchPool2DLayer, self).__init__() |
| 1106 | |
| 1107 | if hparams["mode"] == "max": |
| 1108 | self.layer1 = nn.MaxPool2d( |
| 1109 | kernel_size=hparams["kernel_shape"], |
| 1110 | padding=hparams["pad"], |
| 1111 | stride=hparams["stride"], |
| 1112 | ) |
| 1113 | elif hparams["mode"] == "average": |
| 1114 | self.layer1 = nn.AvgPool2d( |
| 1115 | kernel_size=hparams["kernel_shape"], |
| 1116 | padding=hparams["pad"], |
| 1117 | stride=hparams["stride"], |
| 1118 | ) |
| 1119 | |
| 1120 | def forward(self, X): |
| 1121 | # (N, H, W, C) -> (N, C, H, W) |