Raises an exception if the tensor rank is not of the expected rank. Args: tensor: A tf.Tensor to check the rank of. expected_rank: Python integer or list of integers, expected rank. name: Optional name of the tensor for the error message. Raises: ValueError: If the actual rank of the tensor doesn't match the expected rank.
(tensor, expected_rank, name=None)
| 31 | |
| 32 | |
def assert_rank(tensor, expected_rank, name=None):
    """Raises an exception if the tensor rank is not of the expected rank.

    Args:
      tensor: A tf.Tensor to check the rank of.
      expected_rank: Python integer or list of integers, expected rank.
      name: Optional name of the tensor for the error message. Defaults to
        `tensor.name` when omitted.

    Raises:
      ValueError: If the actual rank of `tensor` is not one of the expected
        ranks. This includes the case where the static rank is unknown,
        unless `None` was explicitly listed in `expected_rank`.
    """
    if name is None:
        name = tensor.name

    # Normalize a single integer into a collection so the membership test
    # below is uniform for both accepted forms of `expected_rank`.
    if isinstance(expected_rank, six.integer_types):
        expected_ranks = {expected_rank}
    else:
        expected_ranks = set(expected_rank)

    # `ndims` is None when the static rank is unknown.
    actual_rank = tensor.shape.ndims
    if actual_rank not in expected_ranks:
        scope_name = tf.get_variable_scope().name
        # `actual_rank` is formatted with %s rather than %d. It may be None,
        # and "%d" % None would raise TypeError instead of the documented
        # ValueError. For integer ranks the message text is unchanged.
        raise ValueError(
            "For the tensor `%s` in scope `%s`, the actual rank "
            "`%s` (shape = %s) is not equal to the expected rank `%s`" %
            (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
| 61 | |
| 62 | |
| 63 | def get_shape_list(tensor, expected_rank=None, name=None): |