| 106 | return self.run(data_loader, model, optimizer, stage, epoch) |
| 107 | |
| 108 | def run(self, data_loader, model, optimizer, stage, |
| 109 | epoch, mode=ModeType.EVAL): |
| 110 | is_multi = False |
| 111 | # multi-label classification |
| 112 | if self.conf.task_info.label_type == ClassificationType.MULTI_LABEL: |
| 113 | is_multi = True |
| 114 | predict_probs = [] |
| 115 | standard_labels = [] |
| 116 | num_batch = data_loader.__len__() |
| 117 | total_loss = 0. |
| 118 | for batch in data_loader: |
| 119 | # hierarchical classification using hierarchy penalty loss |
| 120 | if self.conf.task_info.hierarchical: |
| 121 | logits = model(batch) |
| 122 | linear_paras = model.linear.weight |
| 123 | is_hierar = True |
| 124 | used_argvs = (self.conf.task_info.hierar_penalty, linear_paras, self.hierar_relations) |
| 125 | loss = self.loss_fn( |
| 126 | logits, |
| 127 | batch[ClassificationDataset.DOC_LABEL].to(self.conf.device), |
| 128 | is_hierar, |
| 129 | is_multi, |
| 130 | *used_argvs) |
| 131 | # hierarchical classification with HMCN |
| 132 | elif self.conf.model_name == "HMCN": |
| 133 | (global_logits, local_logits, logits) = model(batch) |
| 134 | loss = self.loss_fn( |
| 135 | global_logits, |
| 136 | batch[ClassificationDataset.DOC_LABEL].to(self.conf.device), |
| 137 | False, |
| 138 | is_multi) |
| 139 | loss += self.loss_fn( |
| 140 | local_logits, |
| 141 | batch[ClassificationDataset.DOC_LABEL].to(self.conf.device), |
| 142 | False, |
| 143 | is_multi) |
| 144 | # flat classification |
| 145 | else: |
| 146 | logits = model(batch) |
| 147 | loss = self.loss_fn( |
| 148 | logits, |
| 149 | batch[ClassificationDataset.DOC_LABEL].to(self.conf.device), |
| 150 | False, |
| 151 | is_multi) |
| 152 | if mode == ModeType.TRAIN: |
| 153 | optimizer.zero_grad() |
| 154 | loss.backward() |
| 155 | optimizer.step() |
| 156 | continue |
| 157 | total_loss += loss.item() |
| 158 | if not is_multi: |
| 159 | result = torch.nn.functional.softmax(logits, dim=1).cpu().tolist() |
| 160 | else: |
| 161 | result = torch.sigmoid(logits).cpu().tolist() |
| 162 | predict_probs.extend(result) |
| 163 | standard_labels.extend(batch[ClassificationDataset.DOC_LABEL_LIST]) |
| 164 | if mode == ModeType.EVAL: |
| 165 | total_loss = total_loss / num_batch |