MCPcopy
hub / github.com/ladaapp/lada / split

Method split

lada/models/basicvsrpp/mmagic/data_sample.py:323–365  ·  view source on GitHub ↗

Split a sequence of data samples along the first dimension. Args: allow_nonseq_value (bool): Whether to allow non-sequential data in the split operation. If True, non-sequential data will be copied for all split data samples. Otherwise, an error will be raised. Defaults to False.

(self,
              allow_nonseq_value: bool = False)

Source from the content-addressed store, hash-verified

321 return stacked_data_sample
322
def split(self,
          allow_nonseq_value: bool = False) -> Sequence['DataSample']:
    """Split a stacked data sample along its first dimension.

    Every field is split independently, and the i-th piece of each field is
    stored in the i-th returned ``DataSample``:

    * ``torch.Tensor`` of shape ``(N, *shape)`` -> ``N`` tensors of shape
      ``(*shape)``.
    * ``LabelData`` -> one new ``LabelData`` per element obtained by
      iterating its ``label``.
    * nested ``DataSample`` -> split recursively with the same
      ``allow_nonseq_value`` setting.
    * other values accepted by ``is_splitable_var`` -> distributed
      element-wise.
    * any other value -> deep-copied independently into every split sample
      when ``allow_nonseq_value`` is True; otherwise a ``TypeError`` is
      raised.

    Fields that are metainfo in ``self`` are set as metainfo on the split
    samples; all other fields are set as data.

    Args:
        allow_nonseq_value (bool): Whether to allow non-sequential data in
            the split operation. If True, non-sequential data is deep-copied
            separately into each split data sample. Otherwise, an error is
            raised. Defaults to False.

    Returns:
        Sequence[DataSample]: The list of ``len(self)`` data samples after
        splitting.

    Raises:
        TypeError: If a field is non-sequential and ``allow_nonseq_value``
            is False.
    """
    num_samples = len(self)
    # The set of metainfo keys does not change while splitting, so compute
    # it once instead of once per field.
    metainfo_keys = set(self.metainfo_keys())

    # 1. split
    data_sample_list = [DataSample() for _ in range(num_samples)]
    for k in self.all_keys():
        stacked_value = self.get(k)
        if isinstance(stacked_value, torch.Tensor):
            # split tensor shape like (N, *shape) to N (*shape) tensors
            values = list(stacked_value)
        elif isinstance(stacked_value, LabelData):
            # iterate `label` (its first dimension when it is a tensor) and
            # wrap each piece in its own LabelData
            values = [LabelData(label=l_) for l_ in stacked_value.label]
        elif isinstance(stacked_value, DataSample):
            # forward the flag so nested samples follow the same policy for
            # non-sequential fields
            values = stacked_value.split(
                allow_nonseq_value=allow_nonseq_value)
        elif is_splitable_var(stacked_value):
            values = stacked_value
        elif allow_nonseq_value:
            # One independent copy per sample. `[deepcopy(x)] * n` would
            # place the *same* copy in every sample, so mutating the field on
            # one split sample would silently change it on all the others.
            values = [deepcopy(stacked_value) for _ in range(num_samples)]
        else:
            raise TypeError(
                f'\'{k}\' is non-sequential data and '
                '\'allow_nonseq_value\' is False. Please check your '
                'data sample or set \'allow_nonseq_value\' as True '
                f'to copy field \'{k}\' for all split data sample.')

        field = 'metainfo' if k in metainfo_keys else 'data'
        for data, v in zip(data_sample_list, values):
            data.set_field(v, k, field_type=field)

    return data_sample_list
366
367 def __len__(self):
368 """Get the length of the data sample."""

Callers 15

load_fonts — Function · 0.80
convert_to_yolo — Function · 0.80
get_video_meta_data — Function · 0.80
get_audio_codec — Function · 0.80
device_to_gpu_id — Function · 0.80
_update_env_var — Function · 0.80
_init_env — Method · 0.80

Calls 3

DataSample — Class · 0.85
is_splitable_var — Function · 0.85
get — Method · 0.45

Tested by

no test coverage detected