hub / github.com/Turing-Project/WriteGPT / get_attention_mask

Function get_attention_mask

LanguageNetwork/GPT2/scripts/utils.py:145–153

A TPU-compatible version of tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd). It returns an [nd, ns] mask, cast to dtype, whose lower-right triangle contains 1s.

(nd, ns, *, dtype)

Source from the content-addressed store, hash-verified

def get_attention_mask(nd, ns, *, dtype):
    """
    this is a TPU compatible version of tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd)
    where the lower right triangle contains 1s
    """
    i = tf.range(nd)[:, None]
    j = tf.range(ns)
    m = i >= j - ns + nd
    return tf.cast(m, dtype)
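To see what the comparison `i >= j - ns + nd` produces, the following is a minimal pure-Python sketch that mirrors the same rule without TensorFlow. The helper name `attention_mask_rows` is hypothetical and is not part of the repository. It is meant only to illustrate the mask's shape, not to replace the function above.

```python
def attention_mask_rows(nd, ns):
    """Hypothetical pure-Python mirror of the rule m[i][j] = (i >= j - ns + nd).

    nd is the number of rows and ns is the number of columns,
    matching the parameter names of get_attention_mask.
    """
    return [
        [1 if i >= j - ns + nd else 0 for j in range(ns)]
        for i in range(nd)
    ]


# With nd=3 and ns=5, the 1s fill the lower-right triangle:
# row i has 1s in columns j <= i + (ns - nd).
for row in attention_mask_rows(3, 5):
    print(row)
# [1, 1, 1, 0, 0]
# [1, 1, 1, 1, 0]
# [1, 1, 1, 1, 1]
```

When nd equals ns, the result is the usual lower-triangular mask, for example [[1, 0], [1, 1]] for nd = ns = 2.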

Callers 1

__init__ · Method · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected