MCPcopy
hub / github.com/Tencent/NeuralNLP-NeuralClassifier / run

Method run

train.py:108–185  ·  view source on GitHub ↗
(self, data_loader, model, optimizer, stage,
            epoch, mode=ModeType.EVAL)

Source from the content-addressed store, hash-verified

106 return self.run(data_loader, model, optimizer, stage, epoch)
107
108 def run(self, data_loader, model, optimizer, stage,
109 epoch, mode=ModeType.EVAL):
110 is_multi = False
111 # multi-label classification
112 if self.conf.task_info.label_type == ClassificationType.MULTI_LABEL:
113 is_multi = True
114 predict_probs = []
115 standard_labels = []
116 num_batch = data_loader.__len__()
117 total_loss = 0.
118 for batch in data_loader:
119 # hierarchical classification using hierarchy penalty loss
120 if self.conf.task_info.hierarchical:
121 logits = model(batch)
122 linear_paras = model.linear.weight
123 is_hierar = True
124 used_argvs = (self.conf.task_info.hierar_penalty, linear_paras, self.hierar_relations)
125 loss = self.loss_fn(
126 logits,
127 batch[ClassificationDataset.DOC_LABEL].to(self.conf.device),
128 is_hierar,
129 is_multi,
130 *used_argvs)
131 # hierarchical classification with HMCN
132 elif self.conf.model_name == "HMCN":
133 (global_logits, local_logits, logits) = model(batch)
134 loss = self.loss_fn(
135 global_logits,
136 batch[ClassificationDataset.DOC_LABEL].to(self.conf.device),
137 False,
138 is_multi)
139 loss += self.loss_fn(
140 local_logits,
141 batch[ClassificationDataset.DOC_LABEL].to(self.conf.device),
142 False,
143 is_multi)
144 # flat classification
145 else:
146 logits = model(batch)
147 loss = self.loss_fn(
148 logits,
149 batch[ClassificationDataset.DOC_LABEL].to(self.conf.device),
150 False,
151 is_multi)
152 if mode == ModeType.TRAIN:
153 optimizer.zero_grad()
154 loss.backward()
155 optimizer.step()
156 continue
157 total_loss += loss.item()
158 if not is_multi:
159 result = torch.nn.functional.softmax(logits, dim=1).cpu().tolist()
160 else:
161 result = torch.sigmoid(logits).cpu().tolist()
162 predict_probs.extend(result)
163 standard_labels.extend(batch[ClassificationDataset.DOC_LABEL_LIST])
164 if mode == ModeType.EVAL:
165 total_loss = total_loss / num_batch

Callers 2

trainMethod · 0.95
evalMethod · 0.95

Calls 4

__len__Method · 0.80
stepMethod · 0.80
evaluateMethod · 0.80
warnMethod · 0.80

Tested by

no test coverage detected