MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / CaffeLayerProcessor

Class CaffeLayerProcessor

tensorpack/utils/loadcaffe.py:19–93  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

17
18
19class CaffeLayerProcessor(object):
20
21 def __init__(self, net):
22 self.net = net
23 self.layer_names = net._layer_names
24 self.param_dict = {}
25 self.processors = {
26 'Convolution': self.proc_conv,
27 'InnerProduct': self.proc_fc,
28 'BatchNorm': self.proc_bn,
29 'Scale': self.proc_scale
30 }
31
32 def process(self):
33 for idx, layer in enumerate(self.net.layers):
34 param = layer.blobs
35 name = self.layer_names[idx]
36 if layer.type in self.processors:
37 logger.info("Processing layer {} of type {}".format(
38 name, layer.type))
39 dic = self.processors[layer.type](idx, name, param)
40 self.param_dict.update(dic)
41 elif len(layer.blobs) != 0:
42 logger.warn(
43 "{} layer contains parameters but is not supported!".format(layer.type))
44 return self.param_dict
45
46 def proc_conv(self, idx, name, param):
47 assert len(param) <= 2
48 assert param[0].data.ndim == 4
49 # caffe: ch_out, ch_in, h, w
50 W = param[0].data.transpose(2, 3, 1, 0)
51 if len(param) == 1:
52 return {name + '/W': W}
53 else:
54 return {name + '/W': W,
55 name + '/b': param[1].data}
56
57 def proc_fc(self, idx, name, param):
58 # TODO caffe has an 'transpose' option for fc/W
59 assert len(param) == 2
60 prev_layer_name = self.net.bottom_names[name][0]
61 prev_layer_output = self.net.blobs[prev_layer_name].data
62 if prev_layer_output.ndim == 4:
63 logger.info("FC layer {} takes spatial data.".format(name))
64 W = param[0].data
65 # original: outx(CxHxW)
66 W = W.reshape((-1,) + prev_layer_output.shape[1:]).transpose(2, 3, 1, 0)
67 # become: (HxWxC)xout
68 else:
69 W = param[0].data.transpose()
70 return {name + '/W': W,
71 name + '/b': param[1].data}
72
73 def proc_bn(self, idx, name, param):
74 scale_factor = param[2].data[0]
75 return {name + '/mean/EMA': param[0].data / scale_factor,
76 name + '/variance/EMA': param[1].data / scale_factor}

Callers 1

load_caffeFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected