Repository: github.com/Turing-Project/WriteGPT

Function get_shape_list

LanguageNetwork/GPT2/scripts/utils.py:61–95

Returns a list of the shape of a tensor, preferring static dimensions: dimensions known at graph-construction time come back as Python integers, and unknown ones come back as scalar tf.Tensor values. If `expected_rank` is given and the tensor has a different rank, an exception is thrown. `name` is used only in that error message.
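The static-first merge can be sketched without TensorFlow. In this sketch, `static_shape` stands in for `tensor.shape.as_list()` (with `None` for unknown sizes) and `dynamic_shape_fn` stands in for a call to `tf.shape(tensor)`. Both parameter names and the helper `merge_shape` itself are hypothetical. The callable is invoked only when at least one dimension is unknown, which mirrors how get_shape_list calls `tf.shape` only in that case.

```python
from typing import Any, Callable, List, Optional, Sequence


def merge_shape(static_shape: Sequence[Optional[int]],
                dynamic_shape_fn: Callable[[], Sequence[Any]]) -> List[Any]:
    """Hypothetical helper mirroring get_shape_list's merge step.

    Known (static) sizes are kept as Python ints; each None is replaced by
    the entry at the same index of the dynamic shape.
    """
    shape = list(static_shape)
    non_static_indexes = [i for i, dim in enumerate(shape) if dim is None]
    if not non_static_indexes:
        # Fully static: the dynamic lookup is never performed.
        return shape
    dyn_shape = dynamic_shape_fn()
    for i in non_static_indexes:
        shape[i] = dyn_shape[i]
    return shape


# "batch" is a placeholder for what would be a runtime scalar tensor in TF.
print(merge_shape([None, 128, 768], lambda: ["batch", 128, 768]))
# ['batch', 128, 768]
```

Keeping static sizes as plain ints lets downstream code do ordinary Python arithmetic and comparisons on them, and only the genuinely unknown dimensions pay for a runtime lookup.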

(tensor, expected_rank=None, name=None)


```python
import tensorflow as tf

# assert_rank is a helper from the same module (not shown here).


def get_shape_list(tensor, expected_rank=None, name=None):
    """Returns a list of the shape of tensor, preferring static dimensions.

    Args:
      tensor: A tf.Tensor object to find the shape of.
      expected_rank: (optional) int. The expected rank of `tensor`. If this is
        specified and the `tensor` has a different rank, an exception will be
        thrown.
      name: Optional name of the tensor for the error message.

    Returns:
      A list of dimensions of the shape of tensor. All static dimensions will
      be returned as python integers, and dynamic dimensions will be returned
      as tf.Tensor scalars.
    """
    if name is None:
        name = tensor.name

    if expected_rank is not None:
        assert_rank(tensor, expected_rank, name)

    # Static shape; unknown dimensions appear as None.
    shape = tensor.shape.as_list()

    non_static_indexes = []
    for (index, dim) in enumerate(shape):
        if dim is None:
            non_static_indexes.append(index)

    # Fully static: return the Python ints without touching the graph.
    if not non_static_indexes:
        return shape

    # Fill each unknown dimension from the runtime shape tensor.
    dyn_shape = tf.shape(tensor)
    for index in non_static_indexes:
        shape[index] = dyn_shape[index]
    return shape
```
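Callers can unpack the returned list and build new shapes from it. The callers' code is not shown on this page, so the pattern below is an assumption. Static entries are plain ints and dynamic entries are scalar tensors, and both support `*`, so an expression like `batch_size * seq_length` works either way. `tf.reshape` accepts a list that mixes the two. The helper `flatten_leading_dims` is hypothetical and is exercised here on plain ints so it runs without TensorFlow.

```python
from functools import reduce
from operator import mul
from typing import Any, List, Sequence


def flatten_leading_dims(shape: Sequence[Any]) -> List[Any]:
    """Hypothetical helper: collapse all but the last dimension.

    Works on a get_shape_list-style list whose entries may be Python ints
    or scalar tensors, because both support multiplication.
    """
    if len(shape) < 2:
        raise ValueError("need rank >= 2, got shape %r" % (list(shape),))
    leading = reduce(mul, shape[:-1])  # e.g. batch_size * seq_length
    return [leading, shape[-1]]


# In TF this result would typically be passed to tf.reshape(x, ...).
print(flatten_leading_dims([8, 128, 768]))
# [1024, 768]
```

Because get_shape_list keeps known sizes as ints, a fully static input yields a fully static target shape, and TensorFlow only builds multiplication ops for dimensions that are actually dynamic.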

Callers (13)

attention_layer · Function · 0.90
residual_mlp_layer · Function · 0.90
embed · Function · 0.90
_top_p_sample · Function · 0.90
_top_k_sample · Function · 0.90
__init__ · Method · 0.90
model_fn · Function · 0.90
sample_step · Function · 0.90
initialize_from_context · Function · 0.90
sample · Function · 0.90
cond · Function · 0.90

Calls (2)

shape · Method · 0.80
assert_rank · Function · 0.70
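The source of assert_rank is not shown on this page. Assuming it behaves like the BERT-style helper of the same name, it would accept either a single int or a collection of allowed ranks and raise ValueError on a mismatch. Under that assumption, the check can be sketched over a plain shape list. `check_rank` is a hypothetical name, and the error message format is illustrative only.

```python
from typing import Iterable, Optional, Sequence, Union


def check_rank(shape: Sequence[Optional[int]],
               expected_rank: Union[int, Iterable[int]],
               name: str = "tensor") -> None:
    """Hypothetical stand-in for assert_rank, operating on a shape list.

    Assumption: expected_rank may be one int or several allowed ranks.
    Unknown (None) dimensions still count toward the rank.
    """
    if isinstance(expected_rank, int):
        allowed = {expected_rank}
    else:
        allowed = set(expected_rank)
    actual_rank = len(shape)
    if actual_rank not in allowed:
        raise ValueError(
            "Tensor %s has rank %d, expected one of %s"
            % (name, actual_rank, sorted(allowed)))


check_rank([None, 128, 768], expected_rank=3)   # passes silently
check_rank([None, 768], expected_rank=[2, 3])   # passes: rank 2 is allowed
```

Note that the rank is known even when individual dimensions are not, which is why get_shape_list can validate `expected_rank` before it resolves any dynamic sizes.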

Tested by

no test coverage detected