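Fetch some elements from a tensor at random. If a valid seed is given, elements are sampled based on that seed; otherwise a random seed is generated. Args: tensor (torch.Tensor): the tensor to sample from (it is flattened before sampling). seed (int, optional): seed used to sample the elements. Defaults to None. num_of_fetches (int, optional): the number of elements to fetch. Defaults to 1024.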
(
tensor: torch.Tensor, seed: int = None,
num_of_fetches: int = 1024)
| 30 | |
| 31 | |
def tensor_random_fetch(
    tensor: torch.Tensor, seed: int = None,
    num_of_fetches: int = 1024) -> torch.Tensor:
    """Fetch elements from ``tensor`` at randomly chosen positions.

    The tensor is flattened first, so positions address its elements in
    flattened order regardless of the original shape. Index generation
    depends on ``seed``:

    * ``seed`` given: indices come from ``generate_indexer(..., seed=seed)``.
      This is presumably reproducible for a fixed seed (TODO confirm against
      ``generate_indexer``).
    * ``seed`` is None: indices come from ``generate_torch_indexer`` with no
      explicit seed.

    Args:
        tensor (torch.Tensor): Tensor to sample from. Any shape is accepted,
            but it must contain at least one element.
        seed (int, optional): Seed forwarded to ``generate_indexer``.
            Defaults to None, which selects ``generate_torch_indexer``.
        num_of_fetches (int, optional): Number of elements to fetch. It is
            forwarded to whichever index generator is used. Defaults to 1024.

    Returns:
        torch.Tensor: 1-D tensor holding the selected elements, on the same
        device as ``tensor``. Its length equals the number of indices the
        generator returns, presumably ``num_of_fetches`` (verify against the
        generators).

    Raises:
        AssertionError: If ``tensor`` has no elements. The check runs only
            while assertions are enabled; it is skipped under ``python -O``.
    """
    # Work on a 1-D view so a single index vector can address every element.
    tensor = tensor.flatten()
    num_of_elements = tensor.numel()
    # NOTE(review): ``assert`` disappears under ``-O``. It is kept as-is so
    # callers that rely on AssertionError for empty input are unaffected.
    assert num_of_elements > 0, ('Can not fetch data from empty tensor(0 element).')

    if seed is None:
        indexer = generate_torch_indexer(
            num_of_fetches=num_of_fetches, num_of_elements=num_of_elements)
    else:
        indexer = generate_indexer(
            num_of_fetches=num_of_fetches, num_of_elements=num_of_elements,
            seed=seed)

    # index_select requires an integer (long) index on the same device as
    # the source tensor, whatever type/device the generator produced.
    return tensor.index_select(dim=0, index=indexer.to(tensor.device).long())
| 51 | |
| 52 | |
| 53 | def channel_random_fetch( |
no test coverage detected