class TernaryConv2d(Layer):
    """
    The :class:`TernaryConv2d` class is a 2D ternary CNN layer whose weights take the values -1, 0, or 1 during inference.

    Note that the bias vector is not ternarized.

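    The ternarization step can be sketched in NumPy as follows. This is a minimal
    illustration, not this layer's implementation. It assumes the Ternary Weight
    Networks heuristic: weights whose magnitude exceeds a threshold
    ``delta = 0.7 * mean(|W|)`` keep their sign, the rest become 0, and a single
    scale ``alpha`` equals the mean magnitude of the surviving weights. The helper
    name ``ternarize`` and its ``factor`` argument are hypothetical. Only the
    kernel ``W`` would pass through this step; the bias stays in full precision.

    ```python
    import numpy as np


    def ternarize(w, factor=0.7):
        # Hypothetical helper: returns weights in {-1, 0, +1} and a float scale.
        w = np.asarray(w, dtype=np.float32)
        # Assumed threshold rule: a fixed fraction of the mean absolute weight.
        delta = factor * np.mean(np.abs(w))
        # Large-magnitude weights keep their sign; small ones are zeroed.
        mask = np.abs(w) > delta
        w_ternary = np.where(mask, np.sign(w), 0.0).astype(np.float32)
        # Scale: mean magnitude of the weights that survived the threshold.
        alpha = float(np.mean(np.abs(w[mask]))) if mask.any() else 0.0
        return w_ternary, alpha


    w_t, alpha = ternarize([[0.9, -0.05], [-0.6, 0.1]])
    print(w_t.tolist())  # [[1.0, 0.0], [-1.0, 0.0]]
    print(alpha)  # 0.75
    ```

    In principle, ternary weights let the convolution be computed with additions
    and subtractions only, because scaling by a single ``alpha`` can be applied
    once to the convolution output.
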
    Parameters
    ----------
    n_filter : int
        The number of filters.
    filter_size : tuple of int
        The filter size (height, width).
    strides : tuple of int
        The sliding window strides along the height and width dimensions.
    act : activation function
        The activation function of this layer.
    padding : str
        The padding algorithm type: "SAME" or "VALID".
    use_gemm : boolean
        If True, use gemm instead of ``tf.matmul`` for inference.
        TODO: support gemm
    data_format : str
        "channels_last" (NHWC, default) or "channels_first" (NCHW).
    dilation_rate : tuple of int
        The dilation rate to use for dilated convolution.
    W_init : initializer
        The initializer for the weight matrix.
    b_init : initializer or None
        The initializer for the bias vector. If None, no bias is used.
    in_channels : int
        The number of input channels.
    name : None or str
        A unique layer name.

    Examples
    --------
    With TensorLayer

    >>> import tensorflow as tf
    >>> import tensorlayer as tl
    >>> net = tl.layers.Input([8, 12, 12, 32], name='input')
    >>> ternaryconv2d = tl.layers.TernaryConv2d(
    ...     n_filter=64, filter_size=(5, 5), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='ternaryconv2d'
    ... )(net)
    >>> print(ternaryconv2d)
    output shape : (8, 12, 12, 64)

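    The printed shape follows from the padding rules of ``tf.nn.conv2d``. The
    sketch below assumes ``channels_last`` input and those TensorFlow conventions:
    ``SAME`` gives ``ceil(size / stride)`` and ``VALID`` gives
    ``ceil((size - k_eff + 1) / stride)``, where the dilated filter extent is
    ``k_eff = (k - 1) * dilation + 1``. The helper ``conv2d_output_shape`` is
    hypothetical and not part of TensorLayer.

    ```python
    import math


    def conv2d_output_shape(input_shape, n_filter, filter_size, strides=(1, 1),
                            padding="SAME", dilation_rate=(1, 1)):
        # Hypothetical helper; assumes TensorFlow's SAME/VALID conventions.
        # input_shape is NHWC: (batch, height, width, in_channels).
        batch, height, width, _ = input_shape
        spatial = []
        for size, k, s, d in zip((height, width), filter_size, strides, dilation_rate):
            k_eff = (k - 1) * d + 1  # extent of the dilated filter
            if padding == "SAME":
                spatial.append(math.ceil(size / s))
            elif padding == "VALID":
                spatial.append(math.ceil((size - k_eff + 1) / s))
            else:
                raise ValueError("padding must be 'SAME' or 'VALID'")
        return (batch, spatial[0], spatial[1], n_filter)


    print(conv2d_output_shape((8, 12, 12, 32), 64, (5, 5)))  # (8, 12, 12, 64)
    print(conv2d_output_shape((8, 12, 12, 32), 64, (5, 5), padding="VALID"))  # (8, 8, 8, 64)
    ```
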
    """

    def __init__(
            self,
            n_filter=32,
            filter_size=(3, 3),
            strides=(1, 1),
            act=None,
            padding='SAME',
            use_gemm=False,
            data_format="channels_last",
            dilation_rate=(1, 1),