(
self, name: str, value: np.ndarray, dtype: torch.dtype = torch.float32
)
| 24 | self.rescale_cfg = rescale_cfg |
| 25 | |
| 26 | def register( |
| 27 | self, name: str, value: np.ndarray, dtype: torch.dtype = torch.float32 |
| 28 | ) -> None: |
| 29 | self.register_buffer(name, torch.tensor(value, dtype=dtype)) |
| 30 | |
| 31 | def get_cfg_scale(self, default_cfg_scale: float, model_t: int) -> float: |
| 32 | if self.rescale_cfg and default_cfg_scale > 1: |
no outgoing calls
no test coverage detected