| 177 | |
| 178 | @no_grad |
| 179 | def _cws_post( |
| 180 | self, |
| 181 | result: TokenClassifierResult, |
| 182 | hidden: Dict[str, torch.Tensor], |
| 183 | store: Dict[str, Any], |
| 184 | inputs: List[str] = None, |
| 185 | tokenized: BatchEncoding = None, |
| 186 | ) -> LTPOutput: |
| 187 | crf = result.crf |
| 188 | logits = result.logits |
| 189 | attention_mask = result.attention_mask |
| 190 | |
| 191 | text = [] |
| 192 | char_idx = [] |
| 193 | for raw_text, encodings in zip(inputs, tokenized.encodings): |
| 194 | last = None |
| 195 | text.append([]) |
| 196 | char_idx.append([]) |
| 197 | for idx, current in enumerate(encodings.offsets[1:-1]): |
| 198 | if current == (0, 0): |
| 199 | break |
| 200 | elif current[0] == current[1]: |
| 201 | continue |
| 202 | elif current != last: |
| 203 | text[-1].append(raw_text[current[0] : current[1]]) |
| 204 | char_idx[-1].append(idx) |
| 205 | last = current |
| 206 | text = ["".join(t) for t in text] |
| 207 | |
| 208 | if crf is None: |
| 209 | decoded = logits.argmax(dim=-1) |
| 210 | decoded = decoded.cpu().numpy() |
| 211 | attention_mask = attention_mask.cpu().numpy() |
| 212 | |
| 213 | decoded = [ |
| 214 | [self.cws_vocab[tag] for tag, mask in zip(tags, masks) if mask] |
| 215 | for tags, masks in zip(decoded, attention_mask) |
| 216 | ] |
| 217 | else: |
| 218 | logits = torch.log_softmax(logits, dim=-1) |
| 219 | decoded = crf.decode(logits, attention_mask) |
| 220 | decoded = [[self.cws_vocab[tag] for tag in tags] for tags in decoded] |
| 221 | entities = [get_entities([d[i] for i in idx]) for d, idx in zip(decoded, char_idx)] |
| 222 | entities = [[(e[1], e[2]) for e in se] for se in entities] |
| 223 | |
| 224 | words = [ |
| 225 | [sent[e[0] : e[1] + 1] for e in sent_entities] |
| 226 | for sent, sent_entities in zip(text, entities) |
| 227 | ] |
| 228 | |
| 229 | if len(self.hook): |
| 230 | words = [self.hook.hook(t, w) for t, w in zip(text, words)] |
| 231 | entities = [] |
| 232 | for he in [np.cumsum([len(w) for w in s]) for s in words]: |
| 233 | entities.append([]) |
| 234 | for i, e in enumerate(he): |
| 235 | if i == 0: |
| 236 | entities[-1].append((0, e - 1)) |