Raises an exception if the tensor rank is not of the expected rank. Args: tensor: A tf.Tensor to check the rank of. expected_rank: Python integer or list of integers, expected rank. name: Optional name of the tensor for the error message. Raises: ValueError: If the expected rank doesn't match the actual rank.
(tensor, expected_rank, name=None)
| 29 | |
| 30 | |
def assert_rank(tensor, expected_rank, name=None):
    """Raises an exception if the tensor rank is not of the expected rank.

    Args:
        tensor: A tf.Tensor to check the rank of.
        expected_rank: Python integer or list of integers, expected rank.
        name: Optional name of the tensor for the error message. Defaults to
            `tensor.name`.

    Raises:
        ValueError: If the actual rank is not one of the expected ranks.
    """
    if name is None:
        name = tensor.name

    # Normalise to a set so that a single integer and an iterable of integers
    # share the same membership test below.
    if isinstance(expected_rank, six.integer_types):
        allowed_ranks = {expected_rank}
    else:
        allowed_ranks = set(expected_rank)

    actual_rank = tensor.shape.ndims
    if actual_rank not in allowed_ranks:
        # The current variable scope is included purely to make the error
        # easier to locate; `get_variable_scope()` (not `get_variable()`,
        # which requires a variable name) is the call that returns it.
        scope_name = tf.compat.v1.get_variable_scope().name
        # `%s` rather than `%d` for the rank: it renders integers identically
        # and does not fail if `ndims` is None (rank statically unknown).
        raise ValueError(
            "For the tensor `%s` in scope `%s`, the actual rank "
            "`%s` (shape = %s) is not equal to the expected rank `%s`" %
            (name, scope_name, actual_rank, str(tensor.shape),
             str(expected_rank)))
| 59 | |
| 60 | |
| 61 | def get_shape_list(tensor, expected_rank=None, name=None): |