(self, y: Tensor, mask: Optional[Tensor] = None)
| 1131 | return (T, H, W) |
| 1132 | |
def encode_text(self, y: Tensor, mask: Optional[Tensor] = None):
    """Embed the text input and flatten it into a single packed sequence.

    ``y`` is first passed through ``self.y_embedder``. When ``mask`` is
    given, the embeddings are filtered with ``masked_select`` using a
    condition that is ``False`` where the mask equals 0 and ``True``
    elsewhere, and the lengths are the mask summed along dim 1. Without a
    mask nothing is filtered and every length is ``y.shape[2]``.

    Both results are also registered as network outputs under
    ``encode_text.output.y`` and ``encode_text.output.y_lens``.

    Args:
        y: Text input handed to ``self.y_embedder``.
        mask: Optional mask. It is squeezed twice at dim 1, so it is
            presumably shaped with two size-1 axes there -- TODO confirm
            against the caller.

    Returns:
        A ``(y, y_lens)`` tuple: the embeddings viewed as
        ``(1, -1, self.hidden_size)`` and the per-row lengths tensor.
    """
    embedded = self.y_embedder(y)  # [B, 1, N_token, C]
    batch = embedded.shape[0]

    if mask is None:
        # No mask: every row contributes all shape[2] entries.
        row_lengths = [embedded.shape[2]] * batch
        y_lens = constant(np.array(row_lengths, dtype=np.int64))
        packed = squeeze(embedded, 1).view((1, -1, self.hidden_size))
    else:
        # Tile the mask along dim 0 when its leading size differs from the
        # embedded batch size.
        if mask.shape[0] != batch:
            mask = repeat(mask, sizes=(batch // mask.shape[0], 1))
        mask = squeeze(squeeze(mask, 1), 1)

        tokens = squeeze(embedded, 1)
        # Compare the mask against a zero constant of its own dtype, then
        # turn the result into a keep-condition (False at zeros).
        is_zero = unsqueeze(mask, -1).__eq__(
            constant_to_tensor_(0, dtype=mask.dtype))
        keep = where(is_zero, constant_to_tensor_(False),
                     constant_to_tensor_(True))
        packed = masked_select(tokens, keep).view(
            (1, -1, self.hidden_size))

        # TODO: how to convert y_lens to a list? The list form would be
        # mask.sum(dim=1).tolist(); for now y_lens stays a tensor.
        y_lens = sum(mask, dim=1)

    self.register_network_output('encode_text.output.y', packed)
    self.register_network_output('encode_text.output.y_lens', y_lens)
    return packed, y_lens
| 1156 | |
| 1157 | def unpatchify(self, x: Tensor, N_t: int, N_h: int, N_w: int, R_t: int, |
| 1158 | R_h: int, R_w: int): |
no test coverage detected