hub / github.com/OpenPPL/ppq / batch_random_fetch

Function batch_random_fetch

ppq/utils/fetch.py:83–108

Fetch some elements from each sample in a batched tensor. If a valid seed is given, elements are sampled based on that seed; otherwise a random seed is generated. Result shape: [num_of_batch, fetches_per_batch].

(
    tensor: torch.Tensor,
    fetches_per_batch: int = 1024,
    seed: int = None
    )

Source from the content-addressed store, hash-verified

import torch


def batch_random_fetch(
    tensor: torch.Tensor,
    fetches_per_batch: int = 1024,
    seed: int = None
    ) -> torch.Tensor:
    """Fetch some elements from each sample in a batched tensor. If a valid
    seed is given, elements will be sampled based on your seed, otherwise a
    random seed will be generated.

    result: [num_of_batch, fetches_per_batch]

    Args:
        tensor (torch.Tensor): Batched tensor; dimension 0 is the batch.
        fetches_per_batch (int, optional): Number of elements to fetch from
            each sample. Defaults to 1024.
        seed (int, optional): Seed for the indexer. Defaults to None.

    Returns:
        torch.Tensor: Fetched elements, [num_of_batch, fetches_per_batch].
    """
    # Collapse every non-batch dimension so each sample becomes one row.
    tensor = tensor.flatten(start_dim=1)
    num_of_elements = tensor.shape[-1]
    assert num_of_elements > 0, ('Can not fetch data from empty tensor(0 element).')

    # Without a seed, use the torch-based indexer; otherwise the seeded one.
    if seed is None:
        indexer = generate_torch_indexer(num_of_fetches=fetches_per_batch, num_of_elements=num_of_elements)
    else:
        indexer = generate_indexer(num_of_fetches=fetches_per_batch, num_of_elements=num_of_elements, seed=seed)
    # The same index set is applied to every sample along the last dimension.
    return tensor.index_select(dim=-1, index=indexer.to(tensor.device).long())
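The selection step can be illustrated without torch or ppq. The following is a minimal sketch in plain Python, not the library's implementation: make_indexer and fetch_per_sample are hypothetical names, and make_indexer is only a stand-in for generate_torch_indexer and generate_indexer, whose bodies are not shown here. It assumes positions are drawn with replacement and folds the seeded and unseeded paths into one helper. What it does mirror from the source is that a single index list is built once and applied to every flattened sample.

```python
import random
from typing import List, Optional, Sequence


def make_indexer(num_of_fetches: int, num_of_elements: int,
                 seed: Optional[int] = None) -> List[int]:
    """Hypothetical stand-in for ppq's indexer helpers (not their real logic).

    Draws num_of_fetches positions in [0, num_of_elements), with replacement.
    """
    rng = random.Random(seed)  # seed=None lets Python pick its own seed
    return [rng.randrange(num_of_elements) for _ in range(num_of_fetches)]


def fetch_per_sample(batch: Sequence[Sequence[float]], fetches_per_batch: int,
                     seed: Optional[int] = None) -> List[List[float]]:
    """Pick the same random positions from every (already flattened) sample."""
    num_of_elements = len(batch[0])
    assert num_of_elements > 0, 'Can not fetch data from an empty sample.'
    indexer = make_indexer(fetches_per_batch, num_of_elements, seed)
    # One index list is shared by all samples, mirroring index_select(dim=-1).
    return [[sample[i] for i in indexer] for sample in batch]


# Two flattened samples of four elements each.
batch = [[0.0, 1.0, 2.0, 3.0], [10.0, 11.0, 12.0, 13.0]]
picked = fetch_per_sample(batch, fetches_per_batch=3, seed=7)
```

Because the indices are shared, column j of the result comes from the same position in every sample, and repeating the call with the same seed returns the same selection.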

Callers 2

post_forward_hook · Method · 0.90
push · Method · 0.90

Calls 3

generate_torch_indexer · Function · 0.85
generate_indexer · Function · 0.85
to · Method · 0.80

Tested by

no test coverage detected