| 82 | ) |
| 83 | |
def step(self, task_name: str, batch: Any):
    """Run ``self.forward`` on one batch for ``task_name`` and score it.

    ``batch`` must provide ``"input_ids"`` and ``"attention_mask"``.
    ``"token_type_ids"`` is optional. ``"word_index"`` and
    ``"word_attention_mask"`` are optional and only used by word-level tasks.
    Any optional key that is absent is passed to ``forward`` as ``None``.

    ``forward`` receives, positionally: the task name, input ids, attention
    mask, then the three optional inputs in the order listed above.

    The whole ``batch`` is also unpacked as keyword arguments into
    ``self.criterions[task_name]``, alongside the model outputs.

    Returns:
        A ``(loss, outputs)`` tuple.
    """
    # Required inputs are read first, so a missing key raises KeyError
    # before any optional lookup happens.
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]

    # Optional inputs, in the positional order ``forward`` expects.
    optional_keys = ("token_type_ids", "word_index", "word_attention_mask")
    optional_inputs = tuple(batch.get(key) for key in optional_keys)

    outputs = self.forward(task_name, input_ids, attention_mask, *optional_inputs)

    criterion = self.criterions[task_name]
    return criterion(outputs, **batch), outputs
| 105 | |
| 106 | def training_step(self, batch: Any, batch_idx: int): |
| 107 | task_name = batch["task_name"] |