MCPcopy Index your code
hub / github.com/Turing-Project/WriteGPT / __init__

Method __init__

LanguageNetwork/GPT2/train/modeling.py:411–514  ·  view source on GitHub ↗

:param config: :param is_training: :param input_ids: Tensor of size [batch_size, seq_length] :param cache: Optionally, a tensor containing cached information, of size [batch_size, num_layers, 2, num_heads, cache_length, features]

(self,
                 config: GroverConfig,
                 is_training,
                 input_ids,
                 cache=None,
                 do_cache=False,
                 pad_token_id=0,
                 chop_off_last_token=True,
                 scope=None,
                 reuse=False)

Source from the content-addressed store, hash-verified

409
410class GroverModel(object):
411 def __init__(self,
412 config: GroverConfig,
413 is_training,
414 input_ids,
415 cache=None,
416 do_cache=False,
417 pad_token_id=0,
418 chop_off_last_token=True,
419 scope=None,
420 reuse=False):
421 """
422 :param config:
423 :param is_training:
424 :param input_ids: Tensor thats of size [batch_size, seq_length]
425 :param cache: Optionally, a tensor to use that will contain cached information of the size
426 [batch_size, num_layers, 2, num_heads, cache_length, features]
427 :param do_cache: Whether to cache again.
428 :param pad_token_id: Which token will be used for padding (probably 0.)
429 :param chop_off_last_token: True if we will end up using this for TRAINING only. False if we want to generate.
430 it means the last token in input_ids will not be processed by the model as input
431 :param scope: scope to run this on
432 """
433 self.config = copy.deepcopy(config)
434 self.is_training = is_training
435 self.pad_token_id = pad_token_id
436
437 if not is_training:
438 self.config.hidden_dropout_prob = 0.0
439 self.config.attention_probs_dropout_prob = 0.0
440
441 if chop_off_last_token:
442 self.target_ids = input_ids[:, 1:]
443 self.input_ids = input_ids[:, :-1]
444 else:
445 self.input_ids = input_ids
446 self.target_ids = tf.concat((input_ids[:, 1:],
447 tf.constant(self.pad_token_id, dtype=self.input_ids.dtype,
448 shape=[get_shape_list(self.input_ids, 2)[0], 1])), 1)
449
450 self.batch_size, self.seq_length = get_shape_list(self.input_ids, 2)
451
452 if cache is None:
453 caches = [None] * config.num_hidden_layers
454 self.cache_length = 0
455 else:
456 batch_size_, num_layers_, two_, num_heads_, self.cache_length, features_ = get_shape_list(
457 cache, expected_rank=6)
458 assert batch_size_ == self.batch_size
459 assert num_layers_ == config.num_hidden_layers
460 assert two_ == 2
461 assert num_heads_ == config.num_attention_heads
462 assert features_ == (config.hidden_size // config.num_attention_heads)
463 caches = tf.unstack(cache, axis=1)
464
465 with tf.variable_scope(scope, default_name='newslm', reuse=reuse):
466 with tf.variable_scope("embeddings"):
467 embeddings, self.embedding_table = embed(self.input_ids, config.vocab_size,
468 config.hidden_size,

Callers

nothing calls this directly

Calls 5

get_shape_list · Function · 0.90
get_attention_mask · Function · 0.90
embed · Function · 0.70
attention_layer · Function · 0.70
residual_mlp_layer · Function · 0.70

Tested by

no test coverage detected