A wrapper class to define the common interface used by booster. Args: module (nn.Module): The model to be wrapped.
| 100 | |
| 101 | |
class ModelWrapper(nn.Module):
    """
    A wrapper class to define the common interface used by booster.

    Args:
        module (nn.Module): The model to be wrapped.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def unwrap(self, unwrap_peft: bool = True):
        """
        Unwrap the model to return the original model for checkpoint saving/loading.

        If the wrapped module is itself a ``ModelWrapper``, its own ``unwrap``
        is called, so nested wrappers are peeled recursively.

        Args:
            unwrap_peft (bool): When True and the unwrapped model is a
                ``PeftModel``, the result is wrapped in ``PeftUnwrapMixin``.
                When False, the unwrapped model is returned as-is.
                Defaults to True.

        Returns:
            The unwrapped model, wrapped in ``PeftUnwrapMixin`` only when
            ``unwrap_peft`` is True and the model is a ``PeftModel``.
        """
        if isinstance(self.module, ModelWrapper):
            # Forward the caller's choice: without this the inner wrapper would
            # fall back to its default (True) and could apply the PEFT wrapping
            # even though the caller asked for unwrap_peft=False.
            model = self.module.unwrap(unwrap_peft=unwrap_peft)
        else:
            model = self.module
        if unwrap_peft and isinstance(model, PeftModel):
            model = PeftUnwrapMixin(model)
        return model

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped module unchanged."""
        return self.module(*args, **kwargs)
| 128 | |
| 129 | |
| 130 | class AMPModelMixin: |
no outgoing calls
no test coverage detected
searching dependent graphs…