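Stack a list of data samples into one. Tensor fields are stacked along a new first dimension, and the `label` tensors of `LabelData` fields are stacked into a new `LabelData`; all other values are saved in a list. Args: data_samples (Sequence['DataSample']): A sequence of `DataSample` to stack. Returns: DataSample: The stacked data sample.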
(cls, data_samples: Sequence['DataSample'])
| 275 | |
@classmethod
def stack(cls, data_samples: Sequence['DataSample']) -> 'DataSample':
    """Stack a list of data samples into one.

    Tensor fields are stacked along a new first dimension with
    ``torch.stack``. For ``LabelData`` fields, the ``label`` tensors are
    stacked and wrapped in a new ``LabelData``. Any other field value, and
    every metainfo value, is collected into a list with one entry per
    sample (in input order).

    Args:
        data_samples (Sequence['DataSample']): A non-empty sequence of
            `DataSample` to stack. All samples must share the same data
            field names, the same metainfo keys, and the same exact type
            for each field's value.

    Returns:
        DataSample: The stacked data sample. It is constructed with
        ``cls()``, so calling ``stack`` on a subclass returns an instance
        of that subclass.

    Raises:
        ValueError: If ``data_samples`` is empty, or the samples do not
            share the same data fields or the same metainfo keys.
        TypeError: If the values of one field do not all have exactly the
            same type.
    """
    # Explicit raises instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently disable this validation.
    if not data_samples:
        raise ValueError('`data_samples` must not be empty.')
    first = data_samples[0]

    # 1. check key consistency. Compare as sets: only membership matters,
    # because every value is fetched by name below.
    keys = first.keys()
    key_set = set(keys)
    if any(set(data.keys()) != key_set for data in data_samples):
        raise ValueError('All data samples must have the same data fields.')

    meta_keys = first.metainfo_keys()
    meta_key_set = set(meta_keys)
    if any(set(data.metainfo_keys()) != meta_key_set
           for data in data_samples):
        raise ValueError(
            'All data samples must have the same metainfo keys.')

    # 2. stack data. Use ``cls`` (not a hard-coded class) so that
    # subclasses get instances of themselves back.
    stacked_data_sample = cls()
    for k in keys:
        values = [getattr(data, k) for data in data_samples]

        # 3. check type consistency. Exact-type identity (not isinstance)
        # keeps the original strictness: mixed subclasses are rejected.
        value_type = type(values[0])
        if any(type(val) is not value_type for val in values):
            raise TypeError(
                f'Values of field "{k}" have inconsistent types.')

        # 4. stack
        if isinstance(values[0], torch.Tensor):
            stacked_value = torch.stack(values)
        elif isinstance(values[0], LabelData):
            stacked_labels = torch.stack([data.label for data in values])
            stacked_value = LabelData(label=stacked_labels)
        else:
            # Non-tensor values cannot be stacked; keep them as a list.
            stacked_value = values
        stacked_data_sample.set_field(stacked_value, k)

    # 5. stack metainfo: each key maps to the per-sample values in order.
    for k in meta_keys:
        stacked_data_sample.set_metainfo(
            {k: [data.metainfo[k] for data in data_samples]})

    return stacked_data_sample
| 322 | |
| 323 | def split(self, |
| 324 | allow_nonseq_value: bool = False) -> Sequence['DataSample']: |
no test coverage detected