MCPcopy Index your code
hub / github.com/LargeWorldModel/LWM / VisionTextProcessor

Class VisionTextProcessor

lwm/data.py:126–239  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

124
125
126class VisionTextProcessor(object):
@staticmethod
def get_default_config(updates=None):
    """Return the default processor configuration, with optional overrides.

    Args:
        updates: Optional overrides, in any form ``ConfigDict`` accepts as
            its initial value. ``None`` returns the defaults unchanged.

    Returns:
        A ``ConfigDict`` holding the defaults below with ``updates`` merged
        on top of them.
    """
    defaults = ConfigDict({
        'fields_from_example': '',
        'subfield_separator': ' ',
        'add_bos_token': True,
        'add_eos_token': True,
        'prepend_text': '',
        'fields_index': -1,
        # Denotes end of each frame for video generation.
        'eof_token': 8192,
        # Denotes end of vision generation.
        'eov_token': 8193,
        # 16 x 16 VQ codes.
        'n_tokens_per_frame': 256,
        'max_n_frames': -1,
    })
    if updates is None:
        return defaults

    # Resolve references inside the overrides before merging them in.
    overrides = ConfigDict(updates).copy_and_resolve_references()
    defaults.update(overrides)
    return defaults
143
def __init__(self, config, tokenizer):
    """Set up the processor from a configuration and a tokenizer.

    Args:
        config: Overrides merged onto ``get_default_config()``. It must set
            ``fields_from_example`` to something other than the empty
            string.
        tokenizer: Object exposing ``encode``; it is stored on the instance
            and used here to pre-encode the ``<vision>`` and ``</vision>``
            delimiter strings.

    Raises:
        ValueError: If ``fields_from_example`` is left at its empty default.
    """
    self.config = self.get_default_config(config)
    # Raise explicitly instead of using ``assert`` so the check is not
    # stripped when Python runs with ``-O``.
    if self.config.fields_from_example == '':
        raise ValueError('fields_from_example must be specified.')
    self.tokenizer = tokenizer
    # Encode the vision delimiters once so later calls can reuse them.
    self.vision_start = tokenizer.encode('<vision>')
    self.vision_end = tokenizer.encode('</vision>')
152
153 def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True):
154 if has_aux:
155 example, *aux = example
156 else:
157 aux = tuple()
158 rand_state = random.Random(aux[-1]) # makes augmentations deterministic by line number
159 token_buffer = []
160 loss_mask_buffer = []
161 vision_mask = []
162
163 fields = example[self.config.fields_from_example]
164 if isinstance(fields, (tuple, list)):
165 if self.config.fields_index >= 0:
166 fields = fields[self.config.fields_index]
167 else:
168 # seed based on line number
169 fields = rand_state.choice(fields)
170 fields = fields.split(',')
171
172 if add_bos_token and self.config.add_bos_token:
173 token_buffer.append(self.tokenizer.bos_token_id)
174 loss_mask_buffer.append(0.0)
175 vision_mask.append(False)
176
177 for i, field in enumerate(fields):
178 if field.startswith('[') and field.endswith(']'):
179 # No loss for this field.
180 field = field[1:-1]
181 mask = 0.0
182 else:
183 mask = 1.0

Callers 1

load_datasetMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected