MCPcopy Index your code
hub / github.com/OpenPPL/ppq / FinetuneCheckPoint

Class FinetuneCheckPoint

ppq/quantization/algorithm/training.py:93–125  ·  view source on GitHub ↗

Finetune Check Point stores training loss for variables. It is bound to a specific variable, and collects and stores its fp32 value as a reference. ATTENTION: collecting fp32 values might cause GPU memory overflow, so a seed is used to sample only part of the fp32 values instead (randomly picking about 2000 values from the given tensor).

Source from the content-addressed store, hash-verified

91
92
class FinetuneCheckPoint:
    """Finetune Check Point stores training loss for variables. It is bound to a
    specific variable, and collects and stores its fp32 value as a reference.

    ATTENTION: collecting fp32 value might cause GPU memory overflow, so we use a seed to
    sample only a part of fp32 value instead (randomly pick about 2000 values from given tensor;
    the per-batch count is controlled by `fetchs`).

    Finetune Check Point maintains a seed for data collecting, a best loss, and reference values.
    """
    def __init__(self, variable: str, random_fetch: bool = True, seed: int = None,
                 fetchs: int = NUM_OF_CHECKPOINT_FETCHS) -> None:
        """Create a checkpoint for one monitored variable.

        Args:
            variable: name of the variable this checkpoint monitors.
            random_fetch: when True, `push` sub-samples every tensor through
                `batch_random_fetch` instead of storing the whole tensor.
            seed: seed handed to `batch_random_fetch`. When None, a random seed in
                [0, 0xffffffff] is drawn once here, and that same seed is reused for
                every push (reference and output alike). NOTE(review): presumably this
                keeps sampled positions aligned between reference and output; this
                depends on `batch_random_fetch` and should be confirmed.
            fetchs: forwarded to `batch_random_fetch` as `fetches_per_batch`.
        """
        if seed is None:
            seed = randint(0, 0xffffffff)
        self.monitor_var = variable
        # Initial "best" loss; a large sentinel so the first real loss compares lower.
        self.best_loss = float(1e9)
        self.seed = seed
        self.references = []  # tensors pushed with is_reference=True
        self.outputs = []     # tensors pushed with is_reference=False
        self.fetchs = fetchs
        self.random_fetch = random_fetch

    def push(self, tensor: torch.Tensor, is_reference: bool) -> None:
        """Record a tensor, optionally sub-sampled, as either a reference or an output.

        Args:
            tensor: value of the monitored variable.
            is_reference: True to append to `references`, False to append to `outputs`.
        """
        if self.random_fetch:
            tensor = batch_random_fetch(tensor, seed=self.seed, fetches_per_batch=self.fetchs)
        if is_reference:
            self.references.append(tensor)
        else:
            self.outputs.append(tensor)

    def pop(self) -> Tuple[list, list]:
        """Return the collected `(outputs, references)` lists.

        The lists are returned by reference, not copied, and are not emptied here.

        Raises:
            AssertionError: if the numbers of outputs and references differ.
                Raised explicitly (rather than via `assert`) so the check still runs
                under `python -O`.
        """
        if len(self.outputs) != len(self.references):
            raise AssertionError(
                'Inconsistent samples detected. '
                f'Reference output gets {len(self.references)} samples, '
                f'however output has {len(self.outputs)}.')
        return self.outputs, self.references

    def clear(self):
        """Discard collected outputs.

        References are intentionally left untouched. NOTE(review): presumably this is
        so the fp32 baseline is collected once and reused across evaluations; confirm
        against the caller.
        """
        self.outputs.clear()
126
127
128class RandomMemDataset:

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected