(self, *args, destination=None, prefix="", keep_vars=False)
| 107 | # override PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved |
| 108 | # when saving the whole text encoder model and when LoRA is unloaded or fused |
| 109 | def state_dict(self, *args, destination=None, prefix="", keep_vars=False): |
| 110 | if self.lora_linear_layer is None:  # no LoRA layer attached: expose only the wrapped base layer |
| 111 | return self.regular_linear_layer.state_dict(  # this module's own `prefix` is passed through unchanged, so keys carry no 'regular_linear_layer.' segment |
| 112 | *args, destination=destination, prefix=prefix, keep_vars=keep_vars |
| 113 | ) |
| 114 | |
| 115 | return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)  # LoRA layer present: defer to the parent class's state_dict |
| 116 | |
| 117 | def _fuse_lora(self, lora_scale=1.0, safe_fusing=False): |
| 118 | if self.lora_linear_layer is None: |
no outgoing calls
no test coverage detected