MCPcopy Index your code
hub / github.com/pytorch/pytorch / _apply

Method _apply

caffe2/python/rnn_cell.py:605–700  ·  view source on GitHub ↗
(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    )

Source from the content-addressed store, hash-verified

603class MILSTMCell(LSTMCell):
604
605 def _apply(
606 self,
607 model,
608 input_t,
609 seq_lengths,
610 states,
611 timestep,
612 extra_inputs=None,
613 ):
614 hidden_t_prev, cell_t_prev = states
615
616 fc_input = hidden_t_prev
617 fc_input_dim = self.hidden_size
618
619 if extra_inputs is not None:
620 extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
621 fc_input = brew.concat(
622 model,
623 [hidden_t_prev] + list(extra_input_blobs),
624 self.scope('gates_concatenated_input_t'),
625 axis=2,
626 )
627 fc_input_dim += sum(extra_input_sizes)
628
629 prev_t = brew.fc(
630 model,
631 fc_input,
632 self.scope('prev_t'),
633 dim_in=fc_input_dim,
634 dim_out=self.gates_size,
635 axis=2,
636 )
637
638 # defining initializers for MI parameters
639 alpha = model.create_param(
640 self.scope('alpha'),
641 shape=[self.gates_size],
642 initializer=Initializer('ConstantFill', value=1.0),
643 )
644 beta_h = model.create_param(
645 self.scope('beta1'),
646 shape=[self.gates_size],
647 initializer=Initializer('ConstantFill', value=1.0),
648 )
649 beta_i = model.create_param(
650 self.scope('beta2'),
651 shape=[self.gates_size],
652 initializer=Initializer('ConstantFill', value=1.0),
653 )
654 b = model.create_param(
655 self.scope('b'),
656 shape=[self.gates_size],
657 initializer=Initializer('ConstantFill', value=0.0),
658 )
659
660 # alpha * input_t + beta_h
661 # Shape: [1, batch_size, 4 * hidden_size]
662 alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(

Callers

nothing calls this directly

Calls 8

Initializer (class) · 0.90
list (function) · 0.85
concat (method) · 0.80
AddExternalOutputs (method) · 0.80
sum (function) · 0.50
scope (method) · 0.45
create_param (method) · 0.45
sum (method) · 0.45

Tested by

no test coverage detected