(features, att_mode, index)
| 69 | |
| 70 | |
def _zero_rows(count, like):
    """Return ``count`` zero rows with ``like``'s trailing shape, dtype and device."""
    # Sized explicitly: ``torch.zeros_like(like[:count])`` silently yields fewer
    # than ``count`` rows whenever ``like`` itself has fewer rows than that.
    return torch.zeros(count, *like.shape[1:], device=like.device, dtype=like.dtype)


def get_audio_features(features, att_mode, index, win_size=None):
    """Select the rows of ``features`` (along dim 0) used for position ``index``.

    Args:
        features: tensor whose dim 0 is indexed/sliced; trailing dims are kept.
        att_mode: 0 -> only row ``index`` (dim 0 kept, size 1);
                  1 -> the ``win_size`` rows immediately before ``index``
                       (``index`` excluded), zero-padded on the left near the start;
                  2 -> a ``win_size``-row window around ``index``
                       (``win_size // 2`` rows before it, the rest from ``index`` on),
                       zero-padded on either side at the boundaries.
        index: position along dim 0.
        win_size: window length for modes 1 and 2. Defaults to the module-level
            ``hparams['smo_win_size']`` when omitted (original behaviour).

    Returns:
        A tensor slice of ``features``, or a new tensor when padding was needed.

    Raises:
        NotImplementedError: if ``att_mode`` is not 0, 1 or 2.
    """
    if att_mode == 0:
        # List index keeps dim 0 (size 1) instead of dropping it.
        return features[[index]]
    if att_mode not in (1, 2):
        raise NotImplementedError(f'wrong att_mode: {att_mode}')
    if win_size is None:
        win_size = hparams['smo_win_size']

    if att_mode == 1:
        left = index - win_size
        auds = features[max(left, 0):index]
        if left < 0:
            auds = torch.cat([_zero_rows(-left, auds), auds], dim=0)
        return auds

    # att_mode == 2: window centred on ``index``.
    num_rows = features.shape[0]
    half = win_size // 2
    left = index - half
    right = index + (win_size - half)
    pad_left = max(0, -left)
    pad_right = max(0, right - num_rows)
    auds = features[max(left, 0):min(right, num_rows)]
    if pad_left > 0:
        auds = torch.cat([_zero_rows(pad_left, auds), auds], dim=0)
    if pad_right > 0:
        auds = torch.cat([auds, _zero_rows(pad_right, auds)], dim=0)
    return auds
| 105 | |
| 106 | |
| 107 | @torch.jit.script |
no outgoing calls
no test coverage detected