```python
import tensorflow as tf


def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns a list of the shape of tensor, preferring static dimensions.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and the `tensor` has a different rank, an exception will be
      thrown.
    name: Optional name of the tensor for the error message.

  Returns:
    A list of dimensions of the shape of tensor. All static dimensions will
    be returned as Python integers, and dynamic dimensions will be returned
    as tf.Tensor scalars.
  """
  if name is None:
    name = tensor.name

  if expected_rank is not None:
    # `assert_rank` is not shown in this listing; it is assumed to be a
    # helper defined elsewhere in the same module.
    assert_rank(tensor, expected_rank, name)

  # Static shape: known dimensions are ints, unknown ones are None.
  shape = tensor.shape.as_list()

  # Record which dimensions are only known at run time.
  non_static_indexes = []
  for (index, dim) in enumerate(shape):
    if dim is None:
      non_static_indexes.append(index)

  # Fully static shape: return it without building a dynamic shape op.
  if not non_static_indexes:
    return shape

  # Replace each unknown dimension with the matching scalar from the
  # runtime shape tensor.
  dyn_shape = tf.shape(tensor)
  for index in non_static_indexes:
    shape[index] = dyn_shape[index]
  return shape
```
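The merge step at the heart of `get_shape_list` does not depend on TensorFlow itself, so it can be illustrated without it. The following is a minimal, framework-free sketch under stated assumptions: `merge_static_and_dynamic` is a hypothetical helper, the static shape is a list with `None` for unknown dimensions (the form `tensor.shape.as_list()` returns), `dynamic_shape_fn` is a stand-in for `tf.shape(tensor)`, and the `assert_rank` call is replaced by a plain length check that raises `ValueError`.

```python
from typing import Any, Callable, List, Optional, Sequence


def merge_static_and_dynamic(
    static_shape: Sequence[Optional[int]],
    dynamic_shape_fn: Callable[[], Sequence[Any]],
    expected_rank: Optional[int] = None,
) -> List[Any]:
    """Hypothetical helper: prefer static dims, fill only unknown (None) ones.

    `dynamic_shape_fn` stands in for tf.shape(tensor); it is called at most
    once, and only when at least one dimension is unknown.
    """
    # Simplified stand-in for assert_rank: compare the static rank.
    if expected_rank is not None and len(static_shape) != expected_rank:
        raise ValueError(
            f"expected rank {expected_rank}, got rank {len(static_shape)}")

    shape = list(static_shape)
    non_static_indexes = [i for i, dim in enumerate(shape) if dim is None]

    # Fully static: never compute the dynamic shape.
    if not non_static_indexes:
        return shape

    dyn_shape = dynamic_shape_fn()
    for i in non_static_indexes:
        shape[i] = dyn_shape[i]
    return shape


if __name__ == "__main__":
    # Batch size unknown statically; sequence length and width known.
    print(merge_static_and_dynamic([None, 128, 768],
                                   lambda: ["<dyn batch>", "<dyn seq>", "<dyn width>"]))
```

Deferring the call to `dynamic_shape_fn` mirrors the early return in `get_shape_list`: when every dimension is static, no runtime shape op is created, and callers receive plain Python ints they can use in ordinary arithmetic.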