Repository: github.com/Turing-Project/WriteGPT

Function get_shape_list

LanguageNetwork/GPT2/train/utils.py:63–97

Returns a list of the shape of `tensor`, preferring static dimensions. Args: `tensor`: a `tf.Tensor` whose shape to find. `expected_rank`: (optional) int, the expected rank of `tensor`; if it is specified and `tensor` has a different rank, an exception is thrown. `name`: (optional) name of the tensor used in the error message.

(tensor, expected_rank=None, name=None)


```python
import tensorflow as tf
# assert_rank is a helper defined elsewhere in this codebase (see "Calls" below).


def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns a list of the shape of tensor, preferring static dimensions.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and the `tensor` has a different rank, an exception will be
      thrown.
    name: Optional name of the tensor for the error message.

  Returns:
    A list of dimensions of the shape of tensor. All static dimensions will
    be returned as python integers, and dynamic dimensions will be returned
    as tf.Tensor scalars.
  """
  if name is None:
    name = tensor.name

  if expected_rank is not None:
    assert_rank(tensor, expected_rank, name)

  # Static shape: Python ints where known at graph-construction time, None otherwise.
  shape = tensor.shape.as_list()

  # Collect the axes whose size is only known at run time.
  non_static_indexes = []
  for (index, dim) in enumerate(shape):
    if dim is None:
      non_static_indexes.append(index)

  # Fully static shape: no runtime lookup needed.
  if not non_static_indexes:
    return shape

  # Fill only the unknown axes with scalar tensors from the dynamic shape.
  dyn_shape = tf.shape(tensor)
  for index in non_static_indexes:
    shape[index] = dyn_shape[index]
  return shape
```
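The core idea is that `tensor.shape.as_list()` yields Python ints for dimensions known at graph-construction time and `None` for unknown ones, and only those `None` slots are filled from the runtime `tf.shape(tensor)`. The framework-free sketch below reproduces just that merge step so it can run without TensorFlow; `merge_shapes` and its `dynamic_shape` argument are hypothetical stand-ins (a plain list plays the role of `tf.shape(tensor)`), not part of the WriteGPT code.

```python
from typing import List, Optional, Sequence


def merge_shapes(static_shape: Sequence[Optional[int]],
                 dynamic_shape: Sequence[object]) -> List[object]:
    """Hypothetical stand-in for the merge step inside get_shape_list.

    static_shape mimics tensor.shape.as_list() (None marks an unknown axis);
    dynamic_shape mimics tf.shape(tensor) and has one entry per axis.
    """
    # Copy so the caller's list is not mutated.
    shape: List[object] = list(static_shape)
    # Axes whose size is unknown at graph-construction time.
    non_static_indexes = [i for i, dim in enumerate(shape) if dim is None]
    # Fully static shapes come back as plain Python ints.
    if not non_static_indexes:
        return shape
    # Only the unknown axes are taken from the runtime shape.
    for i in non_static_indexes:
        shape[i] = dynamic_shape[i]
    return shape
```

For example, `merge_shapes([None, 128, 768], ["batch_dim", 128, 768])` returns `['batch_dim', 128, 768]`, while a fully static input such as `[2, 128, 768]` is returned without reading the dynamic shape at all. Keeping known sizes as Python ints is what lets downstream code do ordinary arithmetic on them while the graph is being built.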

Callers: 13 (11 listed)

attention_layer · function · 0.90
residual_mlp_layer · function · 0.90
embed · function · 0.90
_top_p_sample · function · 0.90
_top_k_sample · function · 0.90
__init__ · method · 0.90
model_fn · function · 0.90
sample_step · function · 0.90
initialize_from_context · function · 0.90
sample · function · 0.90
cond · function · 0.90
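The callers' bodies are not shown on this page. As an assumption, a BERT-style attention layer typically unpacks the returned list into named dimensions and builds a reshape target from it, which only works if the hidden size is a static Python int. The sketch below illustrates that pattern on plain lists; `split_heads_shape` is a hypothetical helper and is not confirmed against WriteGPT's `attention_layer`.

```python
from typing import List, Sequence


def split_heads_shape(shape: Sequence[object], num_heads: int) -> List[object]:
    """Hypothetical reshape target that splits [batch, seq, hidden]
    into [batch, seq, num_heads, hidden // num_heads]."""
    if len(shape) != 3:
        raise ValueError(f"expected rank 3, got {len(shape)}")
    batch_size, seq_length, hidden_size = shape
    # The per-head width is computed in Python, so hidden_size must be a
    # static int; batch_size and seq_length may stay dynamic placeholders.
    if not isinstance(hidden_size, int):
        raise TypeError("hidden size must be static to split into heads")
    if hidden_size % num_heads != 0:
        raise ValueError("hidden size must be divisible by num_heads")
    return [batch_size, seq_length, num_heads, hidden_size // num_heads]
```

With a dynamic batch and a static hidden size, `split_heads_shape(["batch_dim", 128, 768], 12)` returns `['batch_dim', 128, 12, 64]`.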

Calls: 2

shape · method · 0.80
assert_rank · function · 0.70
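`get_shape_list` delegates the rank check to `assert_rank`, whose body is not shown on this page. The sketch below is an assumption modeled on the BERT-style helper that accepts either a single int or a collection of allowed ranks and raises `ValueError` on a mismatch; `check_rank` is a hypothetical name, and it takes a plain shape list instead of a `tf.Tensor`.

```python
from typing import Iterable, Optional, Sequence, Union


def check_rank(shape: Sequence[Optional[int]],
               expected_rank: Union[int, Iterable[int]],
               name: str = "tensor") -> None:
    """Hypothetical rank check on a plain shape list (not a tf.Tensor)."""
    # Accept a single int or any collection of allowed ranks.
    if isinstance(expected_rank, int):
        allowed = {expected_rank}
    else:
        allowed = set(expected_rank)
    # Rank is the number of axes, whether or not their sizes are known.
    actual_rank = len(shape)
    if actual_rank not in allowed:
        raise ValueError(
            f"For the tensor `{name}`, the actual rank {actual_rank} "
            f"(shape = {list(shape)}) is not in the expected ranks "
            f"{sorted(allowed)}"
        )
```

Note that an unknown dimension (`None`) still counts toward the rank, so `check_rank([2, None, 768], 3)` passes.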

Tested by

no test coverage detected