A layer that concatenates multiple tensors along a given axis. Parameters ---------- concat_dim : int The dimension to concatenate. name : None or str A unique layer name. Examples ---------- >>> class CustomModel(tl.models.Model): >>> def __init
| 13 | |
| 14 | |
class Concat(Layer):
    """Layer that joins several tensors into one along a chosen axis.

    Parameters
    ----------
    concat_dim : int
        Axis along which the inputs are concatenated. Defaults to -1.
    name : None or str
        A unique layer name.

    Examples
    --------
    >>> class CustomModel(tl.models.Model):
    >>>     def __init__(self):
    >>>         super(CustomModel, self).__init__(name="custom")
    >>>         self.dense1 = tl.layers.Dense(in_channels=20, n_units=10, act=tf.nn.relu, name='relu1_1')
    >>>         self.dense2 = tl.layers.Dense(in_channels=20, n_units=10, act=tf.nn.relu, name='relu2_1')
    >>>         self.concat = tl.layers.Concat(concat_dim=1, name='concat_layer')
    >>>
    >>>     def forward(self, inputs):
    >>>         d1 = self.dense1(inputs)
    >>>         d2 = self.dense2(inputs)
    >>>         outputs = self.concat([d1, d2])
    >>>         return outputs

    """

    def __init__(self, concat_dim=-1, name=None):
        super(Concat, self).__init__(name)
        self.concat_dim = concat_dim

        # build() has an empty body in this class, so it is invoked right away
        # without an input shape, and the layer is then flagged as built.
        self.build(None)
        self._built = True

        message = "Concat %s: concat_dim: %d" % (self.name, concat_dim)
        logging.info(message)

    def __repr__(self):
        """Return ``ClassName(concat_dim=<value>)`` for this layer."""
        return '{classname}(concat_dim={concat_dim})'.format(
            classname=self.__class__.__name__, **self.__dict__
        )

    def build(self, inputs_shape):
        """Do nothing; ``inputs_shape`` is accepted but not used."""

    def forward(self, inputs):
        """Concatenate ``inputs`` along ``self.concat_dim``.

        Parameters
        ----------
        inputs : list of tensors
            Tensors to concatenate (e.g. ``[d1, d2]``); handed to
            ``tf.concat`` unchanged.

        Returns
        -------
        The value returned by ``tf.concat``, created with this layer's name.
        """
        return tf.concat(inputs, self.concat_dim, name=self.name)
no outgoing calls
searching dependent graphs…