(self, layer_idx, kwargs)
| 141 | super().__init__(_distributed_cache_data) |
| 142 | |
def record_kwargs(self, layer_idx, kwargs):
    """Snapshot ``kwargs`` and hand the snapshot to the parent recorder.

    Every tensor value is replaced by a fresh copy so the recorded entry
    does not share storage with the caller's tensor: CUDA tensors are
    copied to host memory with ``.cpu()``, every other tensor is
    duplicated with ``.clone()``. Non-tensor values (``None`` included)
    are forwarded unchanged.

    Args:
        layer_idx: Passed through untouched to the parent's
            ``record_kwargs``.
        kwargs: Mapping of keyword name to value for this layer.
    """

    def _snapshot(value):
        # torch.is_tensor(None) is False, so None needs no separate guard.
        if not torch.is_tensor(value):
            return value
        # NOTE(review): only CUDA tensors are moved to host memory; tensors
        # on other devices are cloned in place on their current device --
        # confirm that this is intended.
        if value.is_cuda:
            return value.cpu()
        return value.clone()

    recorded = {name: _snapshot(value) for name, value in kwargs.items()}
    super().record_kwargs(layer_idx, recorded)
| 151 | |
| 152 | def update( |
| 153 | self, |
no test coverage detected