MCPcopy
hub / github.com/ladaapp/lada / stack

Method stack

lada/models/basicvsrpp/mmagic/data_sample.py:277–321  ·  view source on GitHub ↗

Stack a list of data samples into one. All tensor fields will be stacked along the first dimension; other values will be saved in a list. Args: data_samples (Sequence['DataSample']): A sequence of `DataSample` to stack. Returns: DataSample: The stacked data sample.

(cls, data_samples: Sequence['DataSample'])

Source from the content-addressed store, hash-verified

275
@classmethod
def stack(cls, data_samples: Sequence['DataSample']) -> 'DataSample':
    """Stack a sequence of data samples into a single data sample.

    Each data field is combined across samples according to the type of
    its value:

    * ``torch.Tensor`` values are stacked with ``torch.stack`` along a new
      leading dimension (so they must all share the same shape).
    * ``LabelData`` values have their ``label`` tensors stacked the same
      way and are wrapped in a new ``LabelData``.
    * Any other value is kept as a plain list with one entry per sample.

    Every metainfo key becomes a list holding each sample's value, in the
    order of ``data_samples``.

    Args:
        data_samples (Sequence['DataSample']): A non-empty sequence of
            `DataSample` to stack. All samples must expose the same data
            fields and metainfo keys, and each field must have exactly the
            same value type across samples.

    Returns:
        DataSample: The stacked data sample. When called on a subclass,
        an instance of that subclass is returned.

    Raises:
        AssertionError: If ``data_samples`` is empty, or if the samples
            disagree on their data fields, metainfo keys, or field types.
    """
    # 1. check input and key consistency
    assert len(data_samples) > 0, 'Cannot stack an empty sequence.'
    first = data_samples[0]

    keys = first.keys()
    # Compare as sets so the check does not depend on the order in which
    # `keys()` happens to list the fields of each sample.
    key_set = set(keys)
    assert all(set(data.keys()) == key_set for data in data_samples), \
        'All data samples must have the same data fields.'

    meta_keys = first.metainfo_keys()
    meta_key_set = set(meta_keys)
    assert all(
        set(data.metainfo_keys()) == meta_key_set
        for data in data_samples), \
        'All data samples must have the same metainfo keys.'

    # 2. stack data; `cls()` keeps the result's type in line with the
    # class this classmethod was invoked on.
    stacked_data_sample = cls()
    for k in keys:
        values = [getattr(data, k) for data in data_samples]
        # 3. check type consistency (exact type match, not isinstance)
        value_type = type(values[0])
        assert all(type(val) is value_type for val in values), \
            f'Field "{k}" has inconsistent types across data samples.'

        # 4. stack
        if isinstance(values[0], torch.Tensor):
            stacked_value = torch.stack(values)
        elif isinstance(values[0], LabelData):
            stacked_labels = torch.stack([data.label for data in values])
            stacked_value = LabelData(label=stacked_labels)
        else:
            stacked_value = values
        stacked_data_sample.set_field(stacked_value, k)

    # 5. stack metainfo: each key maps to the per-sample values, in order
    for k in meta_keys:
        meta_values = [data.metainfo[k] for data in data_samples]
        stacked_data_sample.set_metainfo({k: meta_values})

    return stacked_data_sample
322
323 def split(self,
324 allow_nonseq_value: bool = False) -> Sequence['DataSample']:

Callers 15

inferenceFunction · 0.80
draw_mosaic_detectionsFunction · 0.80
restoreMethod · 0.80
_preprocess_cpuMethod · 0.80
_preprocess_gpuMethod · 0.80
__getitem__Method · 0.80
inferenceFunction · 0.80
all_to_tensorFunction · 0.80
gaussianMethod · 0.80
add_imageMethod · 0.80
add_imageMethod · 0.80

Calls 1

DataSampleClass · 0.85

Tested by

no test coverage detected