MCPcopy
hub / github.com/hunkim/PyTorchZeroToAll / RNN

Class RNN

13_3_char_rnn.py:16–45  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

14
15
class RNN(nn.Module):
    """Character-level GRU model that processes one token per call.

    Pipeline: token index -> ``nn.Embedding`` -> ``nn.GRU`` -> ``nn.Linear``,
    producing one row of raw (un-softmaxed) scores over ``output_size``
    classes per step.

    Args:
        input_size: Number of rows in the embedding table (vocabulary size).
        hidden_size: Embedding width and GRU hidden-state width.
        output_size: Number of scores produced per step.
        n_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers

        # The embedding width equals hidden_size, so the GRU maps
        # hidden_size -> hidden_size.
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)
        self.linear = nn.Linear(hidden_size, output_size)

    # This runs one step at a time. It's extremely slow; please do not use it
    # in practice. A real implementation needs (1) batching and (2) data
    # parallelism.
    def forward(self, input, hidden):
        """Run a single time step.

        Args:
            input: Tensor holding exactly one token index (e.g.
                ``torch.tensor([3])``); more elements would make the embedded
                vector wider than the GRU's ``hidden_size`` input.
            hidden: GRU state of shape ``(n_layers, 1, hidden_size)``, e.g.
                from :meth:`init_hidden` or the previous call.

        Returns:
            ``(output, hidden)``: ``output`` has shape ``(1, output_size)``;
            ``hidden`` is the updated ``(n_layers, 1, hidden_size)`` state.
        """
        embed = self.embedding(input.view(1, -1))  # 1 x 1 x hidden_size
        # nn.GRU expects (seq_len, batch, features); both seq_len and batch are 1.
        embed = embed.view(1, 1, -1)
        output, hidden = self.gru(embed, hidden)  # output: 1 x 1 x hidden_size
        output = self.linear(output.view(1, -1))  # 1 x output_size
        return output, hidden

    def init_hidden(self):
        """Return an all-zero GRU state of shape ``(n_layers, 1, hidden_size)``.

        The state is created on the same device and with the same dtype as
        the model's parameters, so it matches the model whether or not the
        model was moved with ``.cuda()`` / ``.to()`` / ``.double()``.
        """
        # Follow the parameters instead of torch.cuda.is_available(): a model
        # left on the CPU on a GPU machine would otherwise receive a CUDA
        # state and fail in forward(). Plain tensors replace the deprecated
        # Variable wrapper (a no-op since PyTorch 0.4).
        reference = next(self.parameters())
        return torch.zeros(
            self.n_layers,
            1,
            self.hidden_size,
            device=reference.device,
            dtype=reference.dtype,
        )
46
47
48def str2tensor(string):

Callers 1

13_3_char_rnn.py · File · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected