(
graph: BaseGraph,
dataloader: Iterable,
interested_outputs: Union[str, List[str]],
collate_fn: Callable = None,
running_device = 'cuda',
samples_per_step: int = 65536,
steps: int = 8,
dequantize: bool = False)
| 135 | |
| 136 | |
def variable_analyse(
    graph: BaseGraph,
    dataloader: Iterable,
    interested_outputs: Union[str, List[str]],
    collate_fn: Callable = None,
    running_device: str = 'cuda',
    samples_per_step: int = 65536,
    steps: int = 8,
    dequantize: bool = False):
    """Show a histogram of the values taken by selected graph variables.

    Runs ``graph`` with a ``TorchExecutor`` on at most ``steps`` batches from
    ``dataloader``, collects values from each requested output via
    ``tensor_random_fetch``, and shows one 64-bin matplotlib histogram per
    variable.

    Args:
        graph: graph to execute.
        dataloader: iterable of input batches.
        interested_outputs: a single variable name or a list of names to analyse.
        collate_fn: optional callable applied to each batch before execution.
        running_device: device passed to ``TorchExecutor``.
        samples_per_step: passed as ``num_of_fetches`` to ``tensor_random_fetch``
            for every output of every batch.
        steps: maximum number of batches to execute.
        dequantize: if True, ``dequantize_graph()`` is called before collecting
            data and ``restore_quantize_state()`` afterwards (also on error).

    Raises:
        ImportError: matplotlib is not installed (checked before any batch runs).
        ValueError: no batch was executed, so there is nothing to plot.
    """
    # A bare string would otherwise be iterated character by character in
    # both zip() and the plotting loop below.
    if isinstance(interested_outputs, str):
        interested_outputs = [interested_outputs]

    # Fail fast, before running any potentially expensive forward passes.
    try:
        from matplotlib import pyplot as plt
    except ImportError as e:
        raise ImportError('Install matplotlib before using this function.') from e

    quant_graph = QuantableGraph(graph)

    executor = TorchExecutor(graph=graph, device=running_device)
    if dequantize: quant_graph.dequantize_graph()

    try:
        data_collector = defaultdict(list)
        for idx, batch in enumerate(dataloader):
            # Checked before execution so that at most `steps` batches run.
            if idx >= steps: break
            if collate_fn is not None: batch = collate_fn(batch)
            fp_outputs = executor.forward(inputs=batch, output_names=interested_outputs)
            for output, output_name in zip(fp_outputs, interested_outputs):
                # Flatten each fetch so concatenation works even if fetch
                # sizes differ between batches.
                data_collector[output_name].append(
                    tensor_random_fetch(tensor=output, num_of_fetches=samples_per_step).flatten()
                )
    finally:
        # Always restore, so an execution error does not leave the caller's
        # graph in a dequantized state.
        if dequantize: quant_graph.restore_quantize_state()

    if not data_collector:
        raise ValueError(
            'No data collected: dataloader yielded no batches (or steps <= 0).')

    for name in interested_outputs:
        values = convert_any_to_numpy(torch.cat(data_collector[name]))

        plt.figure(figsize=[12, 8])
        plt.title(f'Histogram Result of Variable {name}:')
        plt.hist(values, bins=64)
        plt.show()
| 177 | |
| 178 | |
| 179 | def parameter_analyse(graph: BaseGraph): |
nothing calls this directly
no test coverage detected