| 223 | |
| 224 | |
| 225 | def print_output(tokenizer, |
| 226 | output_ids: torch.Tensor, |
| 227 | input_lengths: List[int], |
| 228 | sequence_lengths: torch.Tensor, |
| 229 | output_csv: Optional[str] = None, |
| 230 | output_npy: Optional[str] = None, |
| 231 | context_logits: Optional[torch.Tensor] = None, |
| 232 | generation_logits: Optional[torch.Tensor] = None, |
| 233 | cum_log_probs: Optional[torch.Tensor] = None, |
| 234 | log_probs: Optional[torch.Tensor] = None, |
| 235 | output_logits_npy: Optional[str] = None, |
| 236 | output_cum_log_probs_npy: Optional[str] = None, |
| 237 | output_log_probs_npy: Optional[str] = None): |
| 238 | num_output_sents, num_beams, _ = output_ids.size() |
| 239 | batch_size = len(input_lengths) |
| 240 | num_return_sequences = num_output_sents // batch_size |
| 241 | |
| 242 | if output_csv is None and output_npy is None and tokenizer is not None: |
| 243 | for i in range(batch_size * num_return_sequences): |
| 244 | batch_idx = i // num_return_sequences |
| 245 | seq_idx = i % num_return_sequences |
| 246 | inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist() |
| 247 | input_text = tokenizer.decode(inputs) |
| 248 | if seq_idx == 0: |
| 249 | print(f'Input [Text {batch_idx}]: \"{input_text}\"') |
| 250 | |
| 251 | for beam in range(num_beams): |
| 252 | output_begin = input_lengths[batch_idx] |
| 253 | output_end = sequence_lengths[i][beam] |
| 254 | outputs = output_ids[i][beam][output_begin:output_end].tolist() |
| 255 | output_text = tokenizer.decode(outputs) |
| 256 | index_str = (f'Text {batch_idx} Seq {seq_idx} Beam {beam}' |
| 257 | if num_return_sequences > 1 else |
| 258 | f'Text {batch_idx} Beam {beam}') |
| 259 | print(f'Output [{index_str}]: \"{output_text}\"') |
| 260 | logger.debug(str(outputs)) |
| 261 | |
| 262 | output_ids = output_ids.reshape((-1, output_ids.size(2))) |
| 263 | |
| 264 | if output_csv is not None: |
| 265 | output_file = Path(output_csv) |
| 266 | output_file.parent.mkdir(exist_ok=True, parents=True) |
| 267 | outputs = output_ids.tolist() |
| 268 | with open(output_file, 'w') as csv_file: |
| 269 | writer = csv.writer(csv_file, delimiter=',') |
| 270 | writer.writerows(outputs) |
| 271 | |
| 272 | if output_npy is not None: |
| 273 | output_file = Path(output_npy) |
| 274 | output_file.parent.mkdir(exist_ok=True, parents=True) |
| 275 | outputs = np.array(output_ids.cpu().contiguous(), dtype='int32') |
| 276 | np.save(output_file, outputs) |
| 277 | |
| 278 | # Save context logits |
| 279 | if context_logits is not None and output_logits_npy is not None: |
| 280 | context_logits = torch.cat(context_logits, axis=0) |
| 281 | vocab_size_padded = context_logits.shape[-1] |
| 282 | context_logits = context_logits.reshape([1, -1, vocab_size_padded]) |