| 17 | |
| 18 | |
| 19 | class CaffeLayerProcessor(object): |
| 20 | |
def __init__(self, net):
    """Remember the network and set up the per-layer-type dispatch table.

    Args:
        net: the loaded network object. This class reads its
            ``_layer_names``, ``layers``, ``bottom_names`` and ``blobs``
            attributes.
    """
    self.net = net
    self.layer_names = net._layer_names
    # Filled in by process(); maps parameter names to converted arrays.
    self.param_dict = {}
    # Layer type string -> bound method that converts that layer's blobs.
    self.processors = dict(
        Convolution=self.proc_conv,
        InnerProduct=self.proc_fc,
        BatchNorm=self.proc_bn,
        Scale=self.proc_scale,
    )
| 31 | |
def process(self):
    """Convert every supported layer of the network and collect the results.

    Walks ``self.net.layers`` in order. A layer whose type has an entry in
    ``self.processors`` is converted and its output merged into
    ``self.param_dict``. A layer of any other type that still carries
    parameter blobs is skipped with a warning.

    Returns:
        ``self.param_dict``, the same dict object that accumulates the
        converted parameters.
    """
    for idx, layer in enumerate(self.net.layers):
        blobs = layer.blobs
        name = self.layer_names[idx]
        if layer.type in self.processors:
            logger.info("Processing layer {} of type {}".format(
                name, layer.type))
            converted = self.processors[layer.type](idx, name, blobs)
            self.param_dict.update(converted)
        elif len(blobs) != 0:
            # `warn` is a deprecated alias of `warning` in the logging API.
            logger.warning(
                "{} layer contains parameters but is not supported!".format(layer.type))
    return self.param_dict
| 45 | |
def proc_conv(self, idx, name, param):
    """Convert a Convolution layer's blobs.

    Args:
        idx: position of the layer in the network (unused here).
        name: layer name, used as the prefix of the returned keys.
        param: the layer's blobs; ``param[0]`` holds the 4-D kernel and the
            optional ``param[1]`` holds the bias.

    Returns:
        A dict with ``name + '/W'`` and, unless exactly one blob was given,
        ``name + '/b'``.
    """
    num_blobs = len(param)
    assert num_blobs <= 2
    kernel = param[0].data
    assert kernel.ndim == 4
    # Caffe stores the kernel as (ch_out, ch_in, h, w); reorder the axes
    # to (h, w, ch_in, ch_out).
    converted = {name + '/W': kernel.transpose(2, 3, 1, 0)}
    if num_blobs != 1:
        converted[name + '/b'] = param[1].data
    return converted
| 56 | |
def proc_fc(self, idx, name, param):
    """Convert an InnerProduct (fully-connected) layer's blobs.

    The weight matrix is transposed so the output dimension comes last.
    When the layer's first bottom blob is 4-D, the weights are first
    reshaped to (out, C, H, W) using that blob's shape and then reordered
    to (H, W, C, out).

    Args:
        idx: position of the layer in the network (unused here).
        name: layer name; used to look up the bottom blob in ``self.net``
            and as the prefix of the returned keys.
        param: the layer's blobs; ``param[0]`` is the weight and the
            optional ``param[1]`` is the bias.

    Returns:
        A dict with ``name + '/W'`` and, when a bias blob is present,
        ``name + '/b'``.

    Raises:
        AssertionError: if the layer has no blobs or more than two.
    """
    # TODO caffe has a 'transpose' option for fc/W
    # A layer without a bias term has only the weight blob; accept it the
    # same way proc_conv does.
    assert 1 <= len(param) <= 2, \
        "InnerProduct layer {} has {} blobs".format(name, len(param))
    prev_layer_name = self.net.bottom_names[name][0]
    prev_layer_output = self.net.blobs[prev_layer_name].data
    weight = param[0].data
    if prev_layer_output.ndim == 4:
        logger.info("FC layer {} takes spatial data.".format(name))
        # original: out x (C x H x W)
        weight = weight.reshape(
            (-1,) + prev_layer_output.shape[1:]).transpose(2, 3, 1, 0)
        # now a 4-D array ordered H x W x C x out (not flattened here)
    else:
        weight = weight.transpose()
    converted = {name + '/W': weight}
    if len(param) == 2:
        converted[name + '/b'] = param[1].data
    return converted
| 72 | |
def proc_bn(self, idx, name, param):
    """Convert a BatchNorm layer's blobs into moving mean and variance.

    ``param[0]`` and ``param[1]`` are divided by the scalar stored in
    ``param[2].data[0]``. A zero scalar yields all-zero statistics
    instead of NaN/inf from a division by zero; this mirrors how Caffe's
    BatchNorm layer treats a zero factor (as a zero scale).

    Args:
        idx: position of the layer in the network (unused here).
        name: layer name, used as the prefix of the returned keys.
        param: the layer's blobs; at least three are read.

    Returns:
        A dict with ``name + '/mean/EMA'`` and ``name + '/variance/EMA'``.
    """
    scale_factor = param[2].data[0]
    if scale_factor == 0:
        # Multiplying by 0.0 keeps the array's shape and float dtype.
        mean = param[0].data * 0.0
        variance = param[1].data * 0.0
    else:
        mean = param[0].data / scale_factor
        variance = param[1].data / scale_factor
    return {name + '/mean/EMA': mean,
            name + '/variance/EMA': variance}