hub / github.com/Turing-Project/WriteGPT / assert_rank

Function assert_rank

LanguageNetwork/GPT2/train/utils.py:33–60

Raises an exception if the tensor rank is not of the expected rank.

Args:
  tensor: A tf.Tensor to check the rank of.
  expected_rank: Python integer or list of integers, expected rank.
  name: Optional name of the tensor for the error message.

Raises:
  ValueError: If the expected shape doesn't match the actual shape.

(tensor, expected_rank, name=None)

Source from the content-addressed store, hash-verified

def assert_rank(tensor, expected_rank, name=None):
  """Raises an exception if the tensor rank is not of the expected rank.

  Args:
    tensor: A tf.Tensor to check the rank of.
    expected_rank: Python integer or list of integers, expected rank.
    name: Optional name of the tensor for the error message.

  Raises:
    ValueError: If the expected shape doesn't match the actual shape.
  """
  if name is None:
    name = tensor.name

  expected_rank_dict = {}
  if isinstance(expected_rank, six.integer_types):
    expected_rank_dict[expected_rank] = True
  else:
    for x in expected_rank:
      expected_rank_dict[x] = True

  actual_rank = tensor.shape.ndims
  if actual_rank not in expected_rank_dict:
    scope_name = tf.get_variable_scope().name
    raise ValueError(
        "For the tensor `%s` in scope `%s`, the actual rank "
        "`%d` (shape = %s) is not equal to the expected rank `%s`" %
        (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
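The rank check above can be tried without TensorFlow. The following is a minimal sketch, not the repository's code: `FakeShape`, `FakeTensor`, and `check_rank` are hypothetical names introduced here for illustration, and the variable-scope lookup is left out because it depends on a TensorFlow graph. It keeps the same idea of accepting either a single integer or a collection of allowed ranks.

```python
class FakeShape:
    """Hypothetical stand-in for tf.TensorShape; only `ndims` is modeled."""

    def __init__(self, dims):
        self.dims = tuple(dims)

    @property
    def ndims(self):
        return len(self.dims)

    def __str__(self):
        return str(self.dims)


class FakeTensor:
    """Hypothetical stand-in for tf.Tensor with just `name` and `shape`."""

    def __init__(self, name, dims):
        self.name = name
        self.shape = FakeShape(dims)


def check_rank(tensor, expected_rank, name=None):
    """Same membership test as assert_rank, minus the scope lookup."""
    if name is None:
        name = tensor.name

    # Normalize a single int or an iterable of ints into a set of allowed ranks.
    if isinstance(expected_rank, int):
        allowed = {expected_rank}
    else:
        allowed = set(expected_rank)

    actual_rank = tensor.shape.ndims
    if actual_rank not in allowed:
        raise ValueError(
            "For the tensor `%s`, the actual rank `%s` (shape = %s) "
            "is not equal to the expected rank `%s`"
            % (name, actual_rank, tensor.shape, expected_rank))


hidden = FakeTensor("hidden", (8, 128, 768))
check_rank(hidden, 3)        # passes: rank 3 matches
check_rank(hidden, [2, 3])   # passes: rank 3 is one of the allowed ranks

try:
    check_rank(hidden, 2)    # fails: rank 3 is not 2
except ValueError as err:
    message = str(err)
```

The sketch formats the rank with `%s` rather than `%d`, an intentional deviation so that the message can still be built if a shape object reports an unknown rank as `None`.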

Callers 1

get_shape_list · Function · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected