| 469 | |
| 470 | |
def debug_finetune_data(local_vars, batch_id, tokenizer):
    """Print a readable dump of one sample from a fine-tuning batch.

    Intended as a debugging aid: nothing is returned, everything goes to
    stdout via ``print``.

    Args:
        local_vars: Mapping that must contain the keys ``"tokens"``,
            ``"target_ids"``, ``"attention_mask"``, ``"logit_mask"`` and
            ``"position_ids"``. The values are indexed like tensors
            (``.item()``, ``.tolist()``, ``.size(-1)``, ``.shape``).
            NOTE(review): presumably the caller passes ``locals()`` here;
            confirm against the call site.
        batch_id: Index of the sample inside the batch to dump.
        tokenizer: Object exposing ``IdToToken(id)`` and ``DecodeIds(ids)``.

    Printed, in order:
        1. The tokens before the separator index, space-joined, with every
           ``[MASK]`` token replaced by ``[<position_ids[batch_id][0, i]>]``.
        2. The list of positions at or after the separator index whose
           ``logit_mask`` entry is truthy.
        3. The decoded input tokens at those positions.
        4. The decoded targets (restricted to those positions only when
           ``target_ids`` has more than two dimensions).
        5. The ``position_ids`` columns at those positions.
    """
    tokens = local_vars["tokens"]
    target_ids = local_vars["target_ids"]
    attention_mask = local_vars["attention_mask"]
    logit_mask = local_vars["logit_mask"]
    position_ids = local_vars["position_ids"]

    # attention_mask[batch_id] is read as a single scalar and used as the
    # boundary between the leading segment and the candidate target region.
    # NOTE(review): looks like this is the context length -- confirm.
    boundary = attention_mask[batch_id].item()
    sample_tokens = tokens[batch_id]

    context_pieces = []
    for idx, token_id in enumerate(sample_tokens[:boundary].tolist()):
        piece = tokenizer.IdToToken(token_id)
        if piece == '[MASK]':
            # Show the first-row position id in place of the mask token.
            piece = f"[{position_ids[batch_id][0, idx].item()}]"
        context_pieces.append(piece)
    print(" ".join(context_pieces))

    # Positions past the boundary that are enabled in the logit mask.
    target_positions = [
        idx
        for idx in range(boundary, tokens.size(-1))
        if logit_mask[batch_id][idx]
    ]
    print(target_positions)
    print(tokenizer.DecodeIds(sample_tokens[target_positions].tolist()))

    sample_targets = target_ids[batch_id]
    if len(target_ids.shape) > 2:
        # Higher-rank targets are narrowed to the selected positions first.
        sample_targets = sample_targets[target_positions]
    print(tokenizer.DecodeIds(sample_targets.tolist()))

    print(position_ids[batch_id][:, target_positions])