Example processor that converts a dictionary of texts into tokens.
| 53 | |
| 54 | |
| 55 | class TextProcessor(object): |
| 56 | """ Example processor that converts a dictionary of texts into tokens. """ |
@staticmethod
def get_default_config(updates=None):
    """Build the default processor configuration, optionally overridden.

    Args:
        updates: Optional overrides. When not None, it is wrapped in a
            ConfigDict, passed through copy_and_resolve_references(), and
            merged into the defaults via update().

    Returns:
        A ConfigDict carrying the defaults listed below, with any
        overrides applied on top.
    """
    # Defaults are applied in this order, one attribute at a time.
    defaults = (
        ('fields_from_example', ''),
        ('fields', ''),
        ('subfield_separator', ' '),
        ('add_bos_token', True),
        ('add_eos_token', True),
        ('prepend_text', ''),
    )
    cfg = ConfigDict()
    for name, value in defaults:
        setattr(cfg, name, value)

    if updates is None:
        return cfg

    overrides = ConfigDict(updates).copy_and_resolve_references()
    cfg.update(overrides)
    return cfg
| 69 | |
def __init__(self, config, tokenizer):
    """Create a processor from configuration overrides and a tokenizer.

    Args:
        config: Overrides merged onto the defaults by get_default_config().
            At least one of `fields` or `fields_from_example` must end up
            as a non-empty string.
        tokenizer: Stored unchanged as `self.tokenizer`.

    Raises:
        AssertionError: If both `fields` and `fields_from_example` are
            empty strings after the overrides are applied.
    """
    self.config = self.get_default_config(config)
    # Checked with an explicit raise rather than an `assert` statement so
    # the validation still runs under `python -O`, which strips asserts.
    # The exception type and message are the ones callers already see.
    if self.config.fields == '' and self.config.fields_from_example == '':
        raise AssertionError(
            'Either fields or fields_from_example must be specified.'
        )
    self.tokenizer = tokenizer
| 76 | |
| 77 | def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True): |
| 78 | if has_aux: |
| 79 | example, *aux = example |
| 80 | else: |
| 81 | aux = tuple() |
| 82 | token_buffer = [] |
| 83 | loss_mask_buffer = [] |
| 84 | |
| 85 | if add_bos_token and self.config.add_bos_token: |
| 86 | token_buffer.append(self.tokenizer.bos_token_id) |
| 87 | loss_mask_buffer.append(0.0) |
| 88 | |
| 89 | if self.config.fields_from_example != '': |
| 90 | fields = example[self.config.fields_from_example].split(',') |
| 91 | else: |
| 92 | fields = self.config.fields.split(',') |
| 93 | |
| 94 | for i, field in enumerate(fields): |
| 95 | if field.startswith('[') and field.endswith(']'): |
| 96 | # No loss for this field. |
| 97 | field = field[1:-1] |
| 98 | mask = 0.0 |
| 99 | else: |
| 100 | mask = 1.0 |
| 101 | |
| 102 | if field == '<|bos|>': |
| 103 | token_buffer.append(self.tokenizer.bos_token_id) |
| 104 | loss_mask_buffer.append(mask) |
| 105 | elif field == '<|eos|>': |
| 106 | token_buffer.append(self.tokenizer.eos_token_id) |
| 107 | loss_mask_buffer.append(mask) |
| 108 | else: |
| 109 | subfields = field.split('+') |
| 110 | text = self.config.subfield_separator.join( |
| 111 | [example[subfield] for subfield in subfields] |
| 112 | ) |