(cls, model: ModelPatcher, stg_flag: STGFlag)
| 206 | |
@classmethod
def patch_model(cls, model: ModelPatcher, stg_flag: STGFlag):
    """Replace every transformer block of ``model`` with an STG wrapper patch.

    The blocks come from ``cls.get_transformer_blocks(model)``. Each block is
    wrapped in ``STGBlockWrapper`` together with ``stg_flag`` and its position
    in that sequence. The wrapper is then registered with
    ``model.set_model_patch_replace`` under the ``"dit"`` / ``"double_block"``
    keys, at that same position.

    Args:
        model: The model patcher that receives the replacement patches.
        stg_flag: Passed unchanged to every ``STGBlockWrapper``.

    Returns:
        None. The patches are registered on ``model`` as a side effect.
    """
    blocks = cls.get_transformer_blocks(model)

    position = 0
    for original_block in blocks:
        # The wrapper and the patch slot use the same index, so block N is
        # always replaced by the wrapper built around block N.
        wrapper = STGBlockWrapper(original_block, stg_flag, position)
        model.set_model_patch_replace(wrapper, "dit", "double_block", position)
        position += 1
| 215 | |
| 216 | @staticmethod |
| 217 | def get_transformer_blocks(model: ModelPatcher): |
no test coverage detected