MCPcopy Index your code
hub / github.com/Turing-Project/WriteGPT / GroverModel

Class GroverModel

LanguageNetwork/GPT2/scripts/modeling.py:414–559  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

412
413
414class GroverModel(object):
415 def __init__(self,
416 config: GroverConfig,
417 is_training,
418 input_ids,
419 cache=None,
420 do_cache=False,
421 pad_token_id=0,
422 chop_off_last_token=True,
423 scope=None,
424 reuse=False):
425 """
426 :param config:
427 :param is_training:
428 :param input_ids: Tensor thats of size [batch_size, seq_length]
429 :param cache: Optionally, a tensor to use that will contain cached information of the size
430 [batch_size, num_layers, 2, num_heads, cache_length, features]
431 :param do_cache: Whether to cache again.
432 :param pad_token_id: Which token will be used for padding (probably 0.)
433 :param chop_off_last_token: True if we will end up using this for TRAINING only. False if we want to generate.
434 it means the last token in input_ids will not be processed by the model as input
435 :param scope: scope to run this on
436 """
437 self.config = copy.deepcopy(config)
438 self.is_training = is_training
439 self.pad_token_id = pad_token_id
440
441 if not is_training:
442 self.config.hidden_dropout_prob = 0.0
443 self.config.attention_probs_dropout_prob = 0.0
444
445 if chop_off_last_token:
446 self.target_ids = input_ids[:, 1:]
447 self.input_ids = input_ids[:, :-1]
448 else:
449 self.input_ids = input_ids
450 self.target_ids = tf.concat((input_ids[:, 1:],
451 tf.constant(self.pad_token_id, dtype=self.input_ids.dtype,
452 shape=[get_shape_list(self.input_ids, 2)[0], 1])), 1)
453
454 self.batch_size, self.seq_length = get_shape_list(self.input_ids, 2)
455
456 if cache is None:
457 caches = [None] * config.num_hidden_layers
458 self.cache_length = 0
459 else:
460 batch_size_, num_layers_, two_, num_heads_, self.cache_length, features_ = get_shape_list(
461 cache, expected_rank=6)
462 assert batch_size_ == self.batch_size
463 assert num_layers_ == config.num_hidden_layers
464 assert two_ == 2
465 assert num_heads_ == config.num_attention_heads
466 assert features_ == (config.hidden_size // config.num_attention_heads)
467 caches = tf.unstack(cache, axis=1)
468
469 with tf.compat.v1.variable_scope(scope, default_name='newslm', reuse=reuse):
470 with tf.compat.v1.variable_scope("embeddings"):
471 embeddings, self.embedding_table = embed(self.input_ids, config.vocab_size,

Callers 2

model_fn · Function · 0.70
sample_step · Function · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected