```python
from typing import Optional

import torch

# generate_indexer and generate_torch_indexer are defined elsewhere in this module.


def batch_random_fetch(
    tensor: torch.Tensor,
    fetches_per_batch: int = 1024,
    seed: Optional[int] = None
) -> torch.Tensor:
    """Fetch some elements from each sample in a batched tensor. If a valid
    seed is given, elements are sampled deterministically from that seed;
    otherwise a random seed is generated.

    The same element indices are used for every sample in the batch.

    result: [num_of_batch, fetches_per_batch]

    Args:
        tensor (torch.Tensor): Batched input tensor. Dimension 0 is the batch
            dimension; all remaining dimensions are flattened per sample.
        fetches_per_batch (int, optional): Number of elements fetched from
            each sample. Defaults to 1024.
        seed (int, optional): Seed for the indexer. Defaults to None, in
            which case a random seed is used.

    Returns:
        torch.Tensor: Fetched elements, shape [num_of_batch, fetches_per_batch].
    """
    # Flatten every sample into one row: [num_of_batch, num_of_elements].
    tensor = tensor.flatten(start_dim=1)
    num_of_elements = tensor.shape[-1]
    assert num_of_elements > 0, 'Cannot fetch data from an empty tensor (0 elements).'

    # Build a single index vector of length fetches_per_batch.
    if seed is None:
        indexer = generate_torch_indexer(num_of_fetches=fetches_per_batch, num_of_elements=num_of_elements)
    else:
        indexer = generate_indexer(num_of_fetches=fetches_per_batch, num_of_elements=num_of_elements, seed=seed)
    # Gather the same columns from every row (sample).
    return tensor.index_select(dim=-1, index=indexer.to(tensor.device).long())
```
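The helpers `generate_indexer` and `generate_torch_indexer` are not shown here, so the sketch below uses hypothetical stand-ins to illustrate the batch-fetch pattern end to end. It assumes the indexer draws indices uniformly with replacement via `torch.randint`, and that seeding can be modeled with a `torch.Generator`; the real helpers may sample differently. The names `sketch_indexer` and `sketch_batch_random_fetch` are illustrative only.

```python
from typing import Optional

import torch


def sketch_indexer(num_of_fetches: int, num_of_elements: int,
                   seed: Optional[int] = None) -> torch.Tensor:
    """Hypothetical stand-in for generate_indexer / generate_torch_indexer.

    Assumption: indices are drawn uniformly from [0, num_of_elements),
    with replacement. A seeded generator makes the draw reproducible.
    """
    if seed is None:
        generator = None  # fall back to torch's global RNG
    else:
        generator = torch.Generator().manual_seed(seed)
    return torch.randint(0, num_of_elements, (num_of_fetches,), generator=generator)


def sketch_batch_random_fetch(tensor: torch.Tensor, fetches_per_batch: int = 1024,
                              seed: Optional[int] = None) -> torch.Tensor:
    # [num_of_batch, ...] -> [num_of_batch, num_of_elements]
    tensor = tensor.flatten(start_dim=1)
    num_of_elements = tensor.shape[-1]
    assert num_of_elements > 0, 'Cannot fetch data from an empty tensor (0 elements).'
    indexer = sketch_indexer(fetches_per_batch, num_of_elements, seed)
    # One shared index vector is applied to every row (sample).
    return tensor.index_select(dim=-1, index=indexer.to(tensor.device).long())


x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)  # 2 samples of 12 elements each
a = sketch_batch_random_fetch(x, fetches_per_batch=5, seed=0)
b = sketch_batch_random_fetch(x, fetches_per_batch=5, seed=0)
print(a.shape)  # torch.Size([2, 5])
```

Because one index vector is shared across rows, column j of the result comes from the same flattened position in every sample; in this example the second row is always the first row plus 12, and repeating the call with the same seed returns the same elements.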