MCPcopy
hub / github.com/OpenPPL/ppq / channel_random_fetch

Function channel_random_fetch

ppq/utils/fetch.py:53–80  ·  view source on GitHub ↗

Fetch some elements from a tensor randomly, channel by channel. If a valid seed is given, elements will be sampled based on your seed; otherwise a random seed will be generated. Result: [num_of_channel, fetchs_per_channel]. Args: tensor (torch.Tensor): [description] fetchs…

(
    tensor: torch.Tensor,
    fetchs_per_channel: int = 1024,
    seed: int = None,
    channel_axis: int = 0)

Source from the content-addressed store, hash-verified

51
52
def channel_random_fetch(
    tensor: torch.Tensor,
    fetchs_per_channel: int = 1024,
    seed: int = None,
    channel_axis: int = 0) -> torch.Tensor:
    """Sample elements from ``tensor`` channel by channel.

    The channel axis is swapped with axis 0 and every remaining axis is
    flattened, giving a ``[num_of_channel, elements_per_channel]`` view.
    A single index tensor is then applied along the last axis, so the same
    positions are picked for every channel.

    result: [num_of_channel, fetchs_per_channel]
    (assuming the indexer helpers return ``fetchs_per_channel`` indices;
    they are defined outside this block -- TODO confirm)

    Args:
        tensor (torch.Tensor): Tensor to sample from. Presumably needs at
            least two dimensions, since it is flattened from dim 1 -- confirm
            against callers.
        fetchs_per_channel (int, optional): Number of fetches requested from
            the indexer helper. Defaults to 1024.
        seed (int, optional): When not None, it is forwarded to
            ``generate_indexer``; when None, ``generate_torch_indexer`` is
            used instead. Defaults to None.
        channel_axis (int, optional): Axis of ``tensor`` that holds the
            channels. Defaults to 0.

    Returns:
        torch.Tensor: The selected elements, one row per channel.

    Raises:
        AssertionError: If a channel contains no element after flattening.
    """
    # Bring the channel axis to the front, then collapse the rest: [C, N].
    per_channel = tensor.transpose(0, channel_axis).flatten(start_dim=1)
    pool_size = per_channel.shape[-1]
    assert pool_size > 0, ('Can not fetch data from empty tensor(0 element).')

    if seed is not None:
        picks = generate_indexer(
            num_of_fetches=fetchs_per_channel,
            num_of_elements=pool_size,
            seed=seed)
    else:
        picks = generate_torch_indexer(
            num_of_fetches=fetchs_per_channel,
            num_of_elements=pool_size)

    # index_select needs an integer index living on the same device as the data.
    picks = picks.to(per_channel.device).long()
    return per_channel.index_select(dim=-1, index=picks)
81
82
83def batch_random_fetch(

Callers 1

observeMethod · 0.90

Calls 3

generate_torch_indexerFunction · 0.85
generate_indexerFunction · 0.85
toMethod · 0.80

Tested by

no test coverage detected