The :class:`Input` class is the starting layer of a neural network. Parameters: ``shape`` (tuple of int) — the input shape, including the batch size; ``dtype`` (dtype) — the type of input values, tf.float32 by default; ``name`` (None or str) — a unique layer name.
| 14 | |
| 15 | |
class _InputLayer(Layer):
    """
    The :class:`Input` class is the starting layer of a neural network.

    Parameters
    ----------
    shape : tuple (int)
        The input shape, including the batch size. ``None`` entries are
        allowed; they are replaced by 1 when the layer builds its initial
        output tensor.
    dtype : tf.DType or str
        The type of input values. By default, tf.float32. Anything accepted
        by ``tf.as_dtype`` works (e.g. ``tf.float32``, ``'float32'``, a NumPy
        dtype); qualified strings such as ``'tf.float32'`` are also accepted.
    name : None or str
        A unique layer name.

    Raises
    ------
    RuntimeError
        If ``dtype`` cannot be converted to a TensorFlow dtype.

    """

    def __init__(self, shape, dtype=tf.float32, name=None):
        super(_InputLayer, self).__init__(name)

        dtype = self._resolve_dtype(dtype)

        logging.info("Input %s: %s" % (self.name, str(shape)))
        self.shape = shape  # shape is needed in __repr__

        # Unknown (None) dimensions, e.g. the batch size, are materialised as
        # 1 so the ones-initializer can allocate a concrete tensor.
        shape_without_none = [1 if dim is None else dim for dim in shape]
        outputs = self.forward(tl.initializers.ones()(shape_without_none, dtype=dtype))

        self._built = True

        # The same tensor is passed as both arguments of _add_node; presumably
        # (inputs, outputs) for a node with no predecessor -- confirm against
        # Layer._add_node.
        self._add_node(outputs, outputs)

    @staticmethod
    def _resolve_dtype(dtype):
        """Convert ``dtype`` to a ``tf.DType``, raising RuntimeError if impossible.

        Deliberately avoids ``eval()``: evaluating the string would execute
        arbitrary text (unsafe if it comes from a saved model config), and
        only strings that evaluate to a DType in this module's namespace
        would be accepted, so a plain name such as ``'float32'`` would fail.
        """
        value = dtype
        if isinstance(value, str):
            # Keep only the last dotted component so qualified spellings such
            # as 'tf.float32' or 'tf.dtypes.float32' map to 'float32'.
            value = value.strip().rsplit('.', 1)[-1]
        try:
            return tf.as_dtype(value)
        except (TypeError, ValueError) as err:
            # (dtype,) is a real 1-tuple so formatting cannot fail when the
            # rejected value is itself a tuple.
            raise RuntimeError("%s is not a valid dtype for InputLayer." % (dtype,)) from err

    def __repr__(self):
        s = 'Input(shape=%s' % str(self.shape)
        if self.name is not None:
            s += (', name=\'%s\'' % self.name)
        s += ')'
        return s

    def __call__(self, inputs, *args, **kwargs):
        # Extra positional/keyword arguments are accepted for signature
        # compatibility but are not forwarded to Layer.__call__.
        return super(_InputLayer, self).__call__(inputs)

    def build(self, inputs_shape):
        # Nothing to build: __init__ creates the output and sets self._built.
        pass

    def forward(self, inputs):
        # Identity: an input layer passes its tensor through unchanged.
        return inputs
| 69 | |
| 70 | |
| 71 | def Input(shape, dtype=tf.float32, name=None): |
no outgoing calls
no test coverage detected
searching dependent graphs…