MCPcopy Index your code
hub / github.com/LargeWorldModel/LWM / TextProcessor

Class TextProcessor

lwm/data.py:55–123  ·  view source on GitHub ↗

Example processor that converts a dictionary of texts into tokens.

Source from the content-addressed store, hash-verified

53
54
55class TextProcessor(object):
56 """ Example processor that converts a dictionary of texts into tokens. """
@staticmethod
def get_default_config(updates=None):
    """Return the processor's default configuration, optionally overridden.

    Args:
        updates: Optional overrides. When not None, they are wrapped in a
            ``ConfigDict``, have their references resolved, and are merged
            on top of the defaults.

    Returns:
        A ``ConfigDict`` holding the defaults plus any overrides.
    """
    cfg = ConfigDict()
    # Key in each example whose value is a comma-separated field spec;
    # takes precedence over `fields` when non-empty (see __call__).
    cfg.fields_from_example = ''
    # Static comma-separated field spec, used when `fields_from_example`
    # is empty.
    cfg.fields = ''
    # String used to join the values of '+'-combined subfields.
    cfg.subfield_separator = ' '
    # A BOS token is emitted only if both this flag and the matching
    # __call__ argument are true.
    cfg.add_bos_token = True
    # NOTE(review): presumably the EOS counterpart of `add_bos_token`;
    # its use is not visible in this excerpt -- confirm.
    cfg.add_eos_token = True
    # NOTE(review): use of `prepend_text` is not visible in this excerpt.
    cfg.prepend_text = ''

    if updates is None:
        return cfg

    overrides = ConfigDict(updates).copy_and_resolve_references()
    cfg.update(overrides)
    return cfg
69
def __init__(self, config, tokenizer):
    """Resolve the configuration and keep a reference to the tokenizer.

    Args:
        config: Overrides forwarded to ``get_default_config`` and merged
            on top of the defaults.
        tokenizer: Tokenizer object, stored unchanged on
            ``self.tokenizer``.

    Raises:
        AssertionError: If both ``fields`` and ``fields_from_example``
            are empty strings in the resolved configuration.
    """
    self.config = self.get_default_config(config)
    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped when Python runs with -O, which would let a processor with
    # no field spec be constructed silently. The exception type and
    # message are kept identical for existing callers.
    if self.config.fields == '' and self.config.fields_from_example == '':
        raise AssertionError(
            'Either fields or fields_from_example must be specified.'
        )
    self.tokenizer = tokenizer
76
77 def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True):
78 if has_aux:
79 example, *aux = example
80 else:
81 aux = tuple()
82 token_buffer = []
83 loss_mask_buffer = []
84
85 if add_bos_token and self.config.add_bos_token:
86 token_buffer.append(self.tokenizer.bos_token_id)
87 loss_mask_buffer.append(0.0)
88
89 if self.config.fields_from_example != '':
90 fields = example[self.config.fields_from_example].split(',')
91 else:
92 fields = self.config.fields.split(',')
93
94 for i, field in enumerate(fields):
95 if field.startswith('[') and field.endswith(']'):
96 # No loss for this field.
97 field = field[1:-1]
98 mask = 0.0
99 else:
100 mask = 1.0
101
102 if field == '<|bos|>':
103 token_buffer.append(self.tokenizer.bos_token_id)
104 loss_mask_buffer.append(mask)
105 elif field == '<|eos|>':
106 token_buffer.append(self.tokenizer.eos_token_id)
107 loss_mask_buffer.append(mask)
108 else:
109 subfields = field.split('+')
110 text = self.config.subfield_separator.join(
111 [example[subfield] for subfield in subfields]
112 )

Callers 1

load_dataset · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected