(unet, freeze_model)
| 252 | |
| 253 | |
def create_custom_diffusion(unet, freeze_model: str):
    """Select the trainable parameters of *unet* and install the custom attention processor.

    Every parameter yielded by ``unet.named_parameters()`` gets its
    ``requires_grad`` flag set according to ``freeze_model``:

    * ``'crossattn'``    -- parameters whose name contains ``'attn2'`` are
      trainable; all others are frozen.
    * ``'crossattn_kv'`` -- parameters whose name contains ``'attn2.to_k'``
      or ``'attn2.to_v'`` are trainable; all others are frozen.

    The name of each parameter left trainable is printed. Afterwards every
    sub-module whose exact type is ``Attention`` gets the module-level
    ``set_use_memory_efficient_attention_xformers`` function bound to it as an
    instance attribute, and ``CustomDiffusionAttnProcessor`` is installed via
    ``unet.set_attn_processor``.

    Args:
        unet: Model exposing ``named_parameters()``, ``children()`` and
            ``set_attn_processor()``. It is modified in place.
        freeze_model: ``'crossattn'`` or ``'crossattn_kv'``.

    Returns:
        The same ``unet`` object that was passed in.

    Raises:
        ValueError: If ``freeze_model`` is not one of the supported modes.
    """
    # Validate the mode once, up front: the check does not depend on the
    # parameter being visited, and doing it here rejects a bad mode even when
    # the model yields no parameters at all.
    if freeze_model == 'crossattn':
        trainable_markers = ('attn2',)
    elif freeze_model == "crossattn_kv":
        trainable_markers = ('attn2.to_k', 'attn2.to_v')
    else:
        raise ValueError(
            "freeze_model argument only supports crossattn_kv or crossattn"
        )

    for name, params in unet.named_parameters():
        trainable = any(marker in name for marker in trainable_markers)
        params.requires_grad = trainable
        if trainable:
            print(name)

    # change attn class
    def change_attn(module):
        """Recursively rebind the xformers toggle on every exact ``Attention`` child."""
        for layer in module.children():
            # Exact type match (not isinstance): anything that is not precisely
            # ``Attention`` is treated as a container and descended into.
            if type(layer) is Attention:
                # ``function.__get__`` produces a method bound to this layer,
                # so the instance attribute shadows the class-level method.
                bound_method = set_use_memory_efficient_attention_xformers.__get__(
                    layer, layer.__class__
                )
                setattr(layer, 'set_use_memory_efficient_attention_xformers', bound_method)
            else:
                change_attn(layer)

    change_attn(unet)
    unet.set_attn_processor(CustomDiffusionAttnProcessor())
    return unet
| 285 | |
| 286 | |
| 287 | def freeze_params(params): |
no test coverage detected