Finetune Check Point stores training loss for variables. It is bound to a specific variable, and collects and stores its fp32 value as a reference. ATTENTION: collecting the full fp32 value might cause GPU memory overflow, so we use a seed to sample only a part of the fp32 value instead (randomly picking about 2000 values from the given tensor).
| 91 | |
| 92 | |
class FinetuneCheckPoint:
    """Per-variable record of sampled reference values, outputs and best loss.

    A check point is bound to one monitored variable. Tensors pushed with
    ``is_reference=True`` are kept as reference values; every other pushed
    tensor is kept as an output to be compared against those references.

    ATTENTION: collecting complete fp32 values might cause GPU memory overflow,
    so by default each pushed tensor is first reduced by ``batch_random_fetch``
    with this check point's seed (randomly picking about 2000 values from the
    given tensor).

    Attributes:
        monitor_var: Name of the variable this check point is bound to.
        best_loss: Best loss recorded so far; starts at 1e9.
        seed: Seed handed to ``batch_random_fetch`` on every push.
        references: Collected reference tensors.
        outputs: Collected output tensors.
        fetchs: Value passed as ``fetches_per_batch`` to ``batch_random_fetch``.
        random_fetch: Whether pushed tensors are sampled before being stored.
    """

    def __init__(self, variable: str, random_fetch: bool = True, seed: int = None,
                 fetchs: int = NUM_OF_CHECKPOINT_FETCHS) -> None:
        """Create an empty check point for ``variable``.

        Args:
            variable: Name of the monitored variable.
            random_fetch: If True, sample each pushed tensor instead of
                storing it whole.
            seed: Sampling seed. When None, a random 32-bit seed is drawn.
            fetchs: Number of fetches per batch used when sampling.
        """
        if seed is None:
            seed = randint(0, 0xffffffff)
        self.monitor_var = variable
        self.best_loss = 1e9
        self.seed = seed
        self.references = []
        self.outputs = []
        self.fetchs = fetchs
        self.random_fetch = random_fetch

    def push(self, tensor: torch.Tensor, is_reference: bool) -> None:
        """Store ``tensor`` as a reference or as an output.

        Args:
            tensor: Value to record.
            is_reference: True to append to ``references``, False to append
                to ``outputs``.
        """
        if self.random_fetch:
            # The same seed is used for every push -- presumably so references
            # and outputs are sampled at matching positions (confirm against
            # batch_random_fetch).
            tensor = batch_random_fetch(
                tensor, seed=self.seed, fetches_per_batch=self.fetchs)
        target = self.references if is_reference else self.outputs
        target.append(tensor)

    def pop(self) -> Tuple[list, list]:
        """Return the collected ``(outputs, references)`` pair.

        The returned lists are the check point's own lists, not copies, so a
        later ``clear()`` also empties the returned ``outputs`` list.

        Raises:
            AssertionError: If the number of outputs differs from the number
                of references.
        """
        assert len(self.outputs) == len(self.references), (
            'Inconsistent samples detected. '
            f'Reference output gets {len(self.references)} samples, '
            f'however output has {len(self.outputs)}.')

        return self.outputs, self.references

    def clear(self) -> None:
        """Drop all collected outputs in place; references are kept."""
        self.outputs.clear()
| 126 | |
| 127 | |
| 128 | class RandomMemDataset: |
no outgoing calls
no test coverage detected