def directly_insert_bn_without_init(model):
    """Insert a freshly constructed BN (plus ReLU) after each RepVGGBlock's ``rbr_reparam``.

    Every ``RepVGGBlock`` found in ``model`` must already own a single conv
    ``rbr_reparam``. That conv is replaced in place by an ``nn.Sequential``
    ``conv -> bn -> relu``:

    * ``conv`` copies the old conv's geometry (channels, kernel size, stride,
      padding, dilation, groups, padding mode) but uses ``bias=False``, since
      the following BN supplies the bias.
    * ``bn`` is a new ``nn.BatchNorm2d`` over ``out_channels``.
    * ``relu`` is a new ``nn.ReLU``; the block's ``nonlinearity`` becomes
      ``nn.Identity`` so the activation is not applied twice.

    The new conv/BN parameters are NOT copied from the old ``rbr_reparam``.
    Without a reasonable initialization of the BN the model may break down, so
    you have to load weights obtained through the BN statistics afterwards
    (see the function ``insert_bn`` in this file).

    Args:
        model: module tree searched (recursively) for ``RepVGGBlock`` instances;
            modified in place.

    Raises:
        AttributeError: if a ``RepVGGBlock`` has no ``rbr_reparam`` attribute.
    """
    # Snapshot the module tree first: the loop body swaps submodules, so don't
    # depend on the lazy named_modules() generator while mutating the tree.
    for name, block in list(model.named_modules()):
        if not isinstance(block, RepVGGBlock):
            continue
        print('directly insert a BN with no initialization: ', name)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not hasattr(block, 'rbr_reparam'):
            raise AttributeError(
                f"RepVGGBlock '{name}' has no 'rbr_reparam'; "
                "it must be re-parameterized before inserting a BN")

        old_conv = block.rbr_reparam
        conv_bn_relu = nn.Sequential()
        conv_bn_relu.add_module('conv', nn.Conv2d(
            old_conv.in_channels, old_conv.out_channels,
            old_conv.kernel_size,
            old_conv.stride, old_conv.padding,
            old_conv.dilation,
            old_conv.groups,
            bias=False,  # the BN below provides the bias term
            padding_mode=old_conv.padding_mode))
        conv_bn_relu.add_module('bn', nn.BatchNorm2d(old_conv.out_channels))
        # NOTE: the ReLU is moved from `block.nonlinearity` into this Sequential
        # so conv+bn+relu can be fused with off-the-shelf APIs (see
        # RepVGGWholeQuant.fuse_model).
        # NOTE(review): this assumes `block.nonlinearity` was a ReLU and that
        # nothing (e.g. an SE module) runs between rbr_reparam and nonlinearity
        # in RepVGGBlock.forward — confirm, otherwise the activation order changes.
        conv_bn_relu.add_module('relu', nn.ReLU())
        # Match the device/dtype of the conv being replaced so a model that has
        # already been moved (e.g. to GPU) stays consistent.
        conv_bn_relu.to(device=old_conv.weight.device, dtype=old_conv.weight.dtype)

        block.nonlinearity = nn.Identity()
        # nn.Module.__setattr__ replaces the registered submodule in place, so no
        # prior delattr is needed.
        block.rbr_reparam = conv_bn_relu
| 144 | |
| 145 | def insert_bn(): |