```python
def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns a list of the shape of tensor, preferring static dimensions.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and the `tensor` has a different rank, an exception will be
      thrown.
    name: Optional name of the tensor for the error message.

  Returns:
    A list of dimensions of the shape of tensor. All static dimensions will
    be returned as Python integers, and dynamic dimensions will be returned
    as tf.Tensor scalars.
  """
  if name is None:
    name = tensor.name

  if expected_rank is not None:
    # `assert_rank` is defined elsewhere in this module (not shown here); it
    # raises if the static rank of `tensor` does not match `expected_rank`.
    assert_rank(tensor, expected_rank, name)

  # Static shape: known dimensions are ints, unknown ones are None.
  shape = tensor.shape.as_list()

  # Record the positions of dimensions that are only known at run time.
  non_static_indexes = []
  for (index, dim) in enumerate(shape):
    if dim is None:
      non_static_indexes.append(index)

  # Fully static shape: nothing needs to be looked up at run time.
  if not non_static_indexes:
    return shape

  # Fill each unknown dimension from the run-time shape tensor.
  dyn_shape = tf.shape(tensor)
  for index in non_static_indexes:
    shape[index] = dyn_shape[index]
  return shape
```
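The core idea of `get_shape_list` (keep every dimension known at graph-construction time, and look up only the unknown ones at run time) does not depend on TensorFlow itself. The pure-Python sketch below mirrors that loop under stated assumptions: the helper name `merge_shapes` is hypothetical, `None` stands in for an unknown static dimension just as in `tensor.shape.as_list()`, and a plain list of ints stands in for the `tf.shape(tensor)` result.

```python
from typing import List, Optional, Sequence


def merge_shapes(static_shape: Sequence[Optional[int]],
                 dynamic_shape: Sequence[int]) -> List[int]:
  """Hypothetical helper: prefer static dims, fill None from the dynamic shape."""
  if len(static_shape) != len(dynamic_shape):
    raise ValueError("static and dynamic shapes must have the same rank")
  shape = list(static_shape)
  # Indexes of dimensions that are unknown until run time.
  non_static_indexes = [i for i, dim in enumerate(shape) if dim is None]
  # Only these positions are read from the dynamic shape; known ints are kept.
  for index in non_static_indexes:
    shape[index] = dynamic_shape[index]
  return shape


# A batch of sequences whose batch size is unknown when the graph is built.
print(merge_shapes([None, 128, 768], [32, 128, 768]))  # [32, 128, 768]
```

In the TensorFlow version the filled-in entries are scalar `tf.Tensor` values rather than ints (as the docstring states), so callers get Python ints wherever the shape is static and pay for a run-time lookup only where it is not.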
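`assert_rank`, which `get_shape_list` calls when `expected_rank` is given, is not part of this excerpt. As a hypothetical stand-in, the check can be sketched in plain Python: compare the number of dimensions in the static shape against the expected rank and raise with the tensor's name in the message. The name `check_rank`, the choice of `ValueError`, and accepting a collection of allowed ranks in addition to a single int are all assumptions here, not the module's documented API.

```python
from typing import Iterable, Optional, Sequence, Union


def check_rank(static_shape: Sequence[Optional[int]],
               expected_rank: Union[int, Iterable[int]],
               name: str) -> None:
  """Hypothetical stand-in for assert_rank: raise if the rank is unexpected."""
  # Accept either a single rank or a collection of acceptable ranks.
  if isinstance(expected_rank, int):
    allowed = {expected_rank}
  else:
    allowed = set(expected_rank)
  # Rank is the number of dimensions; it is known even if some dims are None.
  actual_rank = len(static_shape)
  if actual_rank not in allowed:
    raise ValueError(
        "Tensor %s has rank %d, expected one of %s"
        % (name, actual_rank, sorted(allowed)))


check_rank([None, 128, 768], 3, "input_ids")  # passes silently
try:
  check_rank([None, 128], 3, "input_ids")
except ValueError as e:
  print(e)  # Tensor input_ids has rank 2, expected one of [3]
```

Because only the number of dimensions is compared, the check works on a partially unknown static shape, which is why it can run before any dynamic lookup.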
def gelu(input_tensor):