Split a sequence of data samples along the first dimension. Args: allow_nonseq_value (bool): Whether to allow non-sequential data in the split operation. If True, non-sequential data will be copied for all split data samples. Otherwise, an error will be raised. Defaults to False.
(self,
allow_nonseq_value: bool = False)
| 321 | return stacked_data_sample |
| 322 | |
def split(self,
          allow_nonseq_value: bool = False) -> Sequence['DataSample']:
    """Split a stacked data sample along the first dimension.

    Every field returned by ``all_keys()`` is split into ``len(self)``
    pieces and the i-th piece is stored in the i-th output sample, in the
    same field type (``metainfo`` or ``data``) it had on ``self``.

    Args:
        allow_nonseq_value (bool): Whether to allow non-sequential data
            in the split operation. If True, non-sequential data is
            deep-copied once per split data sample, so the outputs do not
            share state. Otherwise, an error is raised. The flag is also
            forwarded to nested ``DataSample`` fields.
            Defaults to False.

    Returns:
        Sequence[DataSample]: The list of data samples after splitting.

    Raises:
        TypeError: If a field is non-sequential and
            ``allow_nonseq_value`` is False.
    """
    num_samples = len(self)
    # Loop-invariant: which keys belong to metainfo rather than data.
    metainfo_keys = set(self.metainfo_keys())

    data_sample_list = [DataSample() for _ in range(num_samples)]
    for k in self.all_keys():
        stacked_value = self.get(k)
        if isinstance(stacked_value, torch.Tensor):
            # Iterating a tensor yields its slices along dim 0, i.e.
            # (N, *shape) -> N tensors of shape (*shape).
            values = list(stacked_value)
        elif isinstance(stacked_value, LabelData):
            # Split the stacked label the same way and re-wrap each
            # slice in its own LabelData.
            values = [
                LabelData(label=label) for label in stacked_value.label
            ]
        elif isinstance(stacked_value, DataSample):
            # Forward the flag so nested non-sequential fields follow
            # the same policy as top-level ones.
            values = stacked_value.split(
                allow_nonseq_value=allow_nonseq_value)
        elif is_splitable_var(stacked_value):
            values = stacked_value
        elif allow_nonseq_value:
            # One independent copy per sample. ``[copy] * n`` would
            # repeat a single object n times, so mutating it through one
            # split sample would silently change all the others.
            values = [deepcopy(stacked_value) for _ in range(num_samples)]
        else:
            raise TypeError(
                f'\'{k}\' is non-sequential data and '
                '\'allow_nonseq_value\' is False. Please check your '
                'data sample or set \'allow_nonseq_value\' as True '
                f'to copy field \'{k}\' for all split data sample.')

        field = 'metainfo' if k in metainfo_keys else 'data'
        for data, v in zip(data_sample_list, values):
            data.set_field(v, k, field_type=field)

    return data_sample_list
| 366 | |
| 367 | def __len__(self): |
| 368 | """Get the length of the data sample.""" |
no test coverage detected