```python
import tensorflow as tf


def get_shape_str(tensors):
    """
    Internally used by the layer registry to print the shapes of a layer's inputs/outputs.

    Args:
        tensors (tf.Tensor, tf.Variable, or a list/tuple of them): the tensor(s) to describe.
    Returns:
        str: the static shape(s) as a string, e.g. "[?, 224, 224, 3]". Unknown
        dimensions are shown as "?", and multiple shapes are joined with ", ".
    """
    if isinstance(tensors, (list, tuple)):
        # Validate every element, then format each one recursively and join the results.
        for v in tensors:
            assert isinstance(v, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(v))
        shape_str = ", ".join(map(get_shape_str, tensors))
    else:
        assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
        # Static shape as a Python list; unknown dimensions (None) are printed as "?".
        shape_str = str(tensors.get_shape().as_list()).replace("None", "?")
    return shape_str
```
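To show the output format of `get_shape_str` without needing TensorFlow installed, the sketch below applies the same formatting rule to plain Python shape lists. `format_shape` and `format_shapes` are hypothetical helpers written for illustration only; they are not part of the library, and they take shape lists (ints or `None`) instead of tensors.

```python
from typing import List, Optional, Sequence

# Hypothetical stand-in for a static shape: one entry per dimension, None = unknown.
Shape = List[Optional[int]]


def format_shape(shape: Shape) -> str:
    """Mirror of the single-tensor branch: str(list) with None shown as '?'."""
    return str(list(shape)).replace("None", "?")


def format_shapes(shapes: Sequence[Shape]) -> str:
    """Mirror of the list/tuple branch: format each shape and join with ', '."""
    return ", ".join(format_shape(s) for s in shapes)


if __name__ == "__main__":
    print(format_shape([None, 224, 224, 3]))             # [?, 224, 224, 3]
    print(format_shapes([[None, 10], [None, 10, 5]]))    # [?, 10], [?, 10, 5]
```

By the same logic, a tensor whose static shape is `[None, 224, 224, 3]` (for example a graph-mode placeholder with an unknown batch size) would be printed by `get_shape_str` as `[?, 224, 224, 3]`. One caveat, assuming standard `tf.TensorShape` behaviour: `as_list()` raises `ValueError` when even the rank is unknown, so `get_shape_str` raises in that case instead of printing `?`.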