Init state. Args: x: encoder output — a single tensor, or a list of tensors (one per encoder) when num_encs > 1.
(self, x)
| 1117 | |
| 1118 | # scorer interface methods |
def init_state(self, x):
    """Init state.

    Build the initial decoder state for the scorer interface and reset the
    attention modules' pre-computed encoder projections.

    Args:
        x: Encoder output. In single-encoder mode (``self.num_encs == 1``)
            a single ``torch.Tensor``; otherwise a list of tensors, one per
            encoder. Only ``x[0]`` is used here, to shape the zero states
            (presumably a per-utterance ``(T, D)`` tensor -- TODO confirm).

    Returns:
        dict: ``c_prev`` / ``z_prev`` -- per-layer zero states (shallow
        copies of the lists kept in ``workspace``); ``a_prev`` -- ``None``
        in single-encoder mode, otherwise ``num_encs + 1`` ``None`` slots
        (one per encoder attention plus the HAN); ``workspace`` --
        ``(att_idx, z_list, c_list)``.
    """
    # To support multi-encoder ASR mode: in single-encoder mode, wrap the
    # torch.Tensor in a list so the code below can always index x[0].
    if self.num_encs == 1:
        x = [x]

    # Add the leading batch dimension once; zero_state is still called per
    # entry so every layer gets its own freshly created state object.
    h0 = x[0].unsqueeze(0)

    # The first layer is created unconditionally, so both lists always hold
    # at least one entry even if self.dlayers < 1.
    c_list = [self.zero_state(h0)]
    z_list = [self.zero_state(h0)]
    for _ in range(1, self.dlayers):
        c_list.append(self.zero_state(h0))
        z_list.append(self.zero_state(h0))

    # TODO(karita): support strm_index for `asr_mix`
    strm_index = 0
    att_idx = min(strm_index, len(self.att) - 1)
    if self.num_encs == 1:
        a = None
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        a = [None] * (self.num_encs + 1)  # atts + han
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han

    # c_prev/z_prev are shallow copies so later updates to the returned
    # state lists do not alias the lists stored in the workspace tuple.
    return dict(
        c_prev=c_list[:],
        z_prev=z_list[:],
        a_prev=a,
        workspace=(att_idx, z_list, c_list),
    )
| 1151 | |
| 1152 | def score(self, yseq, state, x): |
| 1153 | # to support mutiple encoder asr mode, in single encoder mode, |
nothing calls this directly
no test coverage detected