Repository: github.com/hunkim/PyTorchZeroToAll

Function pad_sequences

13_2_rnn_classification.py:49–67
(vectorized_seqs, seq_lengths, countries)


# pad sequences and sort the tensor
def pad_sequences(vectorized_seqs, seq_lengths, countries):
    # Zero-filled (batch, max_len) LongTensor; 0 is the padding value
    seq_tensor = torch.zeros((len(vectorized_seqs), seq_lengths.max())).long()
    for idx, (seq, seq_len) in enumerate(zip(vectorized_seqs, seq_lengths)):
        # Copy each sequence into the left part of its row
        seq_tensor[idx, :seq_len] = torch.LongTensor(seq)

    # Sort tensors by their length (longest first)
    seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
    seq_tensor = seq_tensor[perm_idx]

    # Also sort the target (countries) in the same order
    target = countries2tensor(countries)
    if len(countries):
        target = target[perm_idx]

    # Return variables
    # DataParallel requires everything to be a Variable
    # (pre-0.4 PyTorch; since 0.4, Tensor and Variable are merged)
    return create_variable(seq_tensor), \
        create_variable(seq_lengths), \
        create_variable(target)
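The listing depends on two helpers from the surrounding script, so it cannot run on its own. The sketch below is a self-contained version for experimenting: `ALL_COUNTRIES` and both helper bodies are hypothetical stand-ins (not the script's definitions), `create_variable` skips the `Variable` wrapper, and sizes are converted with `int(...)` and `.tolist()` so that plain Python ints are used for shapes and slices. The pad-and-sort steps otherwise follow the listing.

```python
import torch

# Hypothetical label list for this example only; the real script builds its own.
ALL_COUNTRIES = ["English", "French", "Korean"]


def countries2tensor(countries):
    # Hypothetical stand-in: map each country name to its index in ALL_COUNTRIES.
    return torch.tensor([ALL_COUNTRIES.index(c) for c in countries], dtype=torch.long)


def create_variable(tensor):
    # Hypothetical stand-in: no Variable wrapper (unneeded since PyTorch 0.4),
    # just move the tensor to the GPU when one is available.
    return tensor.cuda() if torch.cuda.is_available() else tensor


def pad_sequences(vectorized_seqs, seq_lengths, countries):
    # Zero-filled (batch, max_len) matrix; 0 doubles as the padding value.
    seq_tensor = torch.zeros((len(vectorized_seqs), int(seq_lengths.max())), dtype=torch.long)
    for idx, (seq, seq_len) in enumerate(zip(vectorized_seqs, seq_lengths.tolist())):
        # Copy each sequence into the left part of its row; the rest stays 0.
        seq_tensor[idx, :seq_len] = torch.tensor(seq, dtype=torch.long)

    # Longest sequence first; perm_idx records which original row lands where.
    seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
    seq_tensor = seq_tensor[perm_idx]

    # Reorder the labels with the same permutation so rows and labels stay paired.
    target = countries2tensor(countries)
    if len(countries):
        target = target[perm_idx]

    return create_variable(seq_tensor), create_variable(seq_lengths), create_variable(target)


# "Hi", "Kim", "L" as ASCII codes, with labels in the same (unsorted) order.
seqs = [[72, 105], [75, 105, 109], [76]]
lengths = torch.tensor([len(s) for s in seqs], dtype=torch.long)
padded, sorted_lengths, target = pad_sequences(seqs, lengths, ["English", "Korean", "French"])
print(padded.tolist())          # [[75, 105, 109], [72, 105, 0], [76, 0, 0]]
print(sorted_lengths.tolist())  # [3, 2, 1]
print(target.tolist())          # [2, 0, 1]
```

Because the padded rows, the lengths and the targets are all reordered with the same `perm_idx`, row i of the sorted batch still belongs to label i. With an empty `countries` list the target is an empty tensor and is left unpermuted.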

Callers (1)

make_variables · function · 0.70
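`make_variables` is the only caller listed. In this script the padded, sorted batch presumably ends up in an RNN through `torch.nn.utils.rnn.pack_padded_sequence`, whose default `enforce_sorted=True` requires lengths in decreasing order; that requirement is the likely reason `pad_sequences` sorts at all. A minimal sketch of that downstream step, assuming a hand-written batch in place of the function's output and arbitrary embedding and GRU sizes (none of these values come from the script):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Stand-in for pad_sequences output: rows sorted longest-first, 0 = padding.
padded = torch.tensor([[75, 105, 109], [72, 105, 0], [76, 0, 0]], dtype=torch.long)
lengths = torch.tensor([3, 2, 1], dtype=torch.long)  # must be on the CPU

# Arbitrary sizes for illustration (128 = ASCII vocabulary).
embedding = nn.Embedding(128, 8)
gru = nn.GRU(input_size=8, hidden_size=16)

# nn.GRU defaults to (seq_len, batch, features), so transpose the batch-first rows.
embedded = embedding(padded.t())                  # shape (3, 3, 8)
packed = pack_padded_sequence(embedded, lengths)  # enforce_sorted=True by default
output, hidden = gru(packed)

# hidden[-1] holds each sequence's state after its last real (non-padding) step.
print(hidden.shape)  # torch.Size([1, 3, 16])

# Unpack the per-step outputs back into a zero-padded tensor plus lengths.
unpacked, out_lengths = pad_packed_sequence(output)
print(unpacked.shape, out_lengths.tolist())  # torch.Size([3, 3, 16]) [3, 2, 1]
```

Since PyTorch 1.1, `pack_padded_sequence(..., enforce_sorted=False)` sorts internally and restores the original order, so newer code can skip the manual sort. Recent releases also require `lengths` to be a CPU tensor, so a lengths tensor moved to the GPU by `create_variable` would need `.cpu()` before packing.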

Calls (2)

countries2tensor · function · 0.85
create_variable · function · 0.85

Tested by

no test coverage detected