Fetch some elements from a tensor randomly, channel by channel. If a valid seed is given, elements will be sampled based on that seed; otherwise a random seed will be generated. Result shape: [num_of_channel, fetchs_per_channel]. Args: tensor (torch.Tensor): [description] fetchs…
(
tensor: torch.Tensor,
fetchs_per_channel: int = 1024,
seed: int = None,
channel_axis: int = 0)
| 51 | |
| 52 | |
def channel_random_fetch(
    tensor: torch.Tensor,
    fetchs_per_channel: int = 1024,
    seed: 'int | None' = None,
    channel_axis: int = 0) -> torch.Tensor:
    """Fetch some elements from ``tensor`` randomly, channel by channel.

    The channel axis is swapped to dim 0 and every remaining dimension is
    flattened, giving a ``[num_of_channel, num_of_elements]`` view. A single
    index vector is then drawn and applied to every channel, so all channels
    are sampled at the same (flattened) positions.

    If a valid seed is given, elements will be sampled based on that seed
    (indices come from ``generate_indexer``); otherwise a random seed will be
    generated (indices come from ``generate_torch_indexer``).

    result: [num_of_channel, fetchs_per_channel]

    Args:
        tensor (torch.Tensor): Tensor to sample from. Expected to have at
            least 2 dimensions, since ``flatten(start_dim=1)`` is applied
            after moving the channel axis to dim 0.
        fetchs_per_channel (int, optional): Number of elements to fetch
            from each channel. Defaults to 1024.
        seed (int | None, optional): Seed for reproducible sampling. When
            ``None``, a non-seeded indexer is used instead. Defaults to None.
        channel_axis (int, optional): Axis of ``tensor`` that holds the
            channels. Defaults to 0.

    Returns:
        torch.Tensor: Sampled values with the same dtype and device as
        ``tensor``. Its shape is ``[num_of_channel, len(indexer)]``, which
        presumably equals ``[num_of_channel, fetchs_per_channel]`` -- TODO
        confirm against the indexer helpers.

    Raises:
        AssertionError: If each channel contains zero elements.
    """
    # Move channels to dim 0, then merge every other dim into one.
    tensor = tensor.transpose(0, channel_axis)
    tensor = tensor.flatten(start_dim=1)
    num_of_elements = tensor.shape[-1]
    # Explicit raise instead of `assert` so the check still runs under
    # `python -O`; AssertionError is kept for compatibility with callers.
    if num_of_elements <= 0:
        raise AssertionError('Can not fetch data from empty tensor(0 element).')

    if seed is None:
        indexer = generate_torch_indexer(
            num_of_fetches=fetchs_per_channel, num_of_elements=num_of_elements)
    else:
        indexer = generate_indexer(
            num_of_fetches=fetchs_per_channel, num_of_elements=num_of_elements, seed=seed)
    # One shared index vector for all channels: move it to the tensor's
    # device and cast to int64 before selecting along the flattened dim.
    return tensor.index_select(dim=-1, index=indexer.to(tensor.device).long())
| 81 | |
| 82 | |
| 83 | def batch_random_fetch( |
no test coverage detected