| 1670 | self.ignore_index = ignore_index |
| 1671 | |
def __call__(self, data):
    """Encode ``data["label"]`` and attach fixed-length targets to ``data``.

    The label is encoded with ``self.encodech`` when ``self.ch`` is truthy
    and with ``self.encode`` otherwise. Both are expected to return a
    3-tuple whose first item is a list of ids (or ``None`` on failure).

    Keys written on success:
        ``length``      -- number of encoded ids, before the eos id is added.
        ``label``       -- ids + eos (``0``), padded with ``self.ignore_index``
                           up to ``self.max_text_len + 1`` entries.
        ``label_node``  -- the encoder's node list followed by a 1/0
                           position mask of ``self.max_text_len + 1`` entries.
        ``label_index`` -- (``self.ch`` mode only) second value returned by
                           ``self.encodech``.
        ``label_order`` -- (otherwise) third value returned by
                           ``self.encode``.

    Returns:
        The same ``data`` dict, updated in place, or ``None`` when the
        encoder yields ``None`` ids or the encoded text is too long.
    """
    ch_mode = bool(self.ch)
    raw_label = data["label"]

    # The two encoders return their auxiliary values in a different order:
    # encodech -> (ids, aux, node list); encode -> (ids, node list, aux).
    if ch_mode:
        token_ids, aux_values, node_prefix = self.encodech(raw_label)
        aux_key = "label_index"
    else:
        token_ids, node_prefix, aux_values = self.encode(raw_label)
        aux_key = "label_order"

    if token_ids is None:
        return None

    cap = self.max_text_len
    n_tokens = len(token_ids)
    # NOTE(review): ch mode accepts a text of exactly `cap` ids while the
    # other mode rejects it (`>` vs `>=`); kept as-is -- confirm whether the
    # asymmetry is intentional.
    too_long = n_tokens > cap if ch_mode else n_tokens >= cap
    if too_long:
        return None

    data["length"] = np.array(n_tokens)

    # 1 for every real id plus the eos slot, 0 for the padded remainder.
    n_pad = cap - n_tokens
    position_mask = [1] * (n_tokens + 1) + [0] * n_pad

    # The eos id is appended in place, so the list returned by the encoder
    # is mutated exactly as before.
    token_ids.append(0)  # eos
    padded_ids = token_ids + [self.ignore_index] * n_pad

    data["label"] = np.array(padded_ids)
    data["label_node"] = np.array(node_prefix + position_mask)
    data[aux_key] = np.array(aux_values)
    return data
| 1713 | |
| 1714 | def add_special_char(self, dict_character): |
| 1715 | dict_character = ["</s>"] + dict_character |