hub / github.com/Turing-Project/WriteGPT / assert_rank

Function assert_rank

LanguageNetwork/GPT2/scripts/utils.py:31–58

Raises an exception if the tensor's rank is not one of the expected ranks.
Args: tensor: A tf.Tensor to check the rank of. expected_rank: Python integer or list of integers, expected rank. name: Optional name of the tensor for the error message.
Raises: ValueError: If the actual rank doesn't match the expected rank.

(tensor, expected_rank, name=None)

Source from the content-addressed store, hash-verified

# Module-level imports this function relies on.
import six
import tensorflow as tf


def assert_rank(tensor, expected_rank, name=None):
    """Raises an exception if the tensor rank is not of the expected rank.

    Args:
        tensor: A tf.Tensor to check the rank of.
        expected_rank: Python integer or list of integers, expected rank.
        name: Optional name of the tensor for the error message.

    Raises:
        ValueError: If the actual rank is not one of the expected ranks.
    """
    if name is None:
        name = tensor.name

    # Accept either a single rank or a collection of allowed ranks.
    expected_rank_dict = {}
    if isinstance(expected_rank, six.integer_types):
        expected_rank_dict[expected_rank] = True
    else:
        for x in expected_rank:
            expected_rank_dict[x] = True

    actual_rank = tensor.shape.ndims
    if actual_rank not in expected_rank_dict:
        # get_variable() requires a variable name and returns a variable, not a
        # scope; the current scope comes from get_variable_scope().
        scope_name = tf.compat.v1.get_variable_scope().name
        raise ValueError(
            "For the tensor `%s` in scope `%s`, the actual rank "
            "`%d` (shape = %s) is not equal to the expected rank `%s`" %
            (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))


# Next in utils.py: get_shape_list(tensor, expected_rank=None, name=None)
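
The rank check above can be tried without TensorFlow. The following is a minimal sketch, not the repository's code: `check_rank` is a hypothetical stand-in that takes a plain shape tuple instead of a tf.Tensor, uses `int` in place of `six.integer_types`, and leaves the variable scope out of the message. It shows only how a single integer and a list of integers are normalized into one collection of allowed ranks before the comparison.

```python
def check_rank(shape, expected_rank, name="tensor"):
    """Raise ValueError unless len(shape) is one of the expected ranks.

    Hypothetical, framework-free stand-in for assert_rank: `shape` is a plain
    tuple of dimensions rather than a tf.Tensor.
    """
    # Normalize a single int or an iterable of ints into one set of allowed ranks.
    if isinstance(expected_rank, int):
        allowed = {expected_rank}
    else:
        allowed = set(expected_rank)

    actual_rank = len(shape)
    if actual_rank not in allowed:
        raise ValueError(
            "For the tensor `%s`, the actual rank `%d` (shape = %s) "
            "is not equal to the expected rank `%s`"
            % (name, actual_rank, shape, expected_rank))


# A rank-2 shape passes a single expected rank.
check_rank((8, 128), 2)

# A rank-3 shape passes when 3 is among the allowed ranks.
check_rank((8, 128, 768), [2, 3])

# A rank-1 shape is rejected, and the message names the offending tensor.
try:
    check_rank((8,), [2, 3], name="input_ids")
except ValueError as err:
    print(err)
```

A set replaces the `expected_rank_dict` lookup here because only membership matters; the behavior is otherwise meant to mirror the function shown above.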

Callers 1

get_shape_list · Function · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected