(vectorized_seqs, seq_lengths, countries)
| 47 | |
| 48 | # pad sequences and sort the tensor |
def pad_sequences(vectorized_seqs, seq_lengths, countries):
    """Right-pad token sequences into one tensor and sort the batch by length.

    Args:
        vectorized_seqs: One sequence of integer token ids per example; each
            is converted to an int64 tensor and copied into its padded row.
        seq_lengths: Tensor of per-sequence lengths. Its maximum sets the
            padded width, and it is sorted along dim 0. Presumably a 1-D
            integer tensor aligned index-for-index with ``vectorized_seqs``
            (TODO: confirm against the caller).
        countries: Labels converted with ``countries2tensor``. When non-empty,
            the resulting target is reordered to match the sorted sequences.

    Returns:
        A 3-tuple ``(seq_tensor, seq_lengths, target)``, each passed through
        ``create_variable``:
          - ``seq_tensor``: the zero-padded ``(batch, max_len)`` int64 tensor.
          - ``seq_lengths``: the lengths.
          - ``target``: the target tensor.
        ``seq_tensor`` and ``seq_lengths`` are in descending-length order.
        ``target`` is reordered the same way when ``countries`` is non-empty.
    """
    # Allocate directly as int64. The previous ``torch.zeros(...).long()``
    # built a float32 tensor and then copied it into a new int64 one.
    seq_tensor = torch.zeros(
        (len(vectorized_seqs), seq_lengths.max()), dtype=torch.long
    )
    for idx, (seq, seq_len) in enumerate(zip(vectorized_seqs, seq_lengths)):
        # Positions at or beyond seq_len keep the padding value 0.
        # torch.as_tensor (unlike the legacy torch.LongTensor constructor)
        # never interprets a bare int as a size for an uninitialized tensor.
        seq_tensor[idx, :seq_len] = torch.as_tensor(seq, dtype=torch.long)

    # Reorder rows longest-first. NOTE(review): presumably this is for a
    # downstream pack_padded_sequence(enforce_sorted=True); confirm with caller.
    seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
    seq_tensor = seq_tensor[perm_idx]

    # Apply the same permutation to the targets so labels stay aligned.
    target = countries2tensor(countries)
    if len(countries):
        target = target[perm_idx]

    # Original note: "DataParallel requires everything to be a Variable".
    # create_variable is defined elsewhere in this file; its exact behavior
    # is not visible here.
    return (
        create_variable(seq_tensor),
        create_variable(seq_lengths),
        create_variable(target),
    )
| 68 | |
| 69 | |
| 70 | # Create necessary variables, lengths, and target |
no test coverage detected