Replace every submodule of a given type in a module with an instance of a new type. Mostly used for deployment.
replace_module(module, replaced_module_type, new_module_type, replace_func=None)
```python
def replace_module(module, replaced_module_type, new_module_type, replace_func=None):
    """
    Replace every submodule of a given type in a module with an instance of a new type.
    Mostly used for deployment.

    Args:
        module (nn.Module): model to apply the replace operation to.
        replaced_module_type (Type): module type to be replaced.
        new_module_type (Type): module type to substitute in.
        replace_func (function): function implementing the replace logic. It is called as
            ``replace_func(replaced_module_type, new_module_type)`` and must return the new
            module instance. Default value None, which instantiates ``new_module_type()``.

    Returns:
        model (nn.Module): the module after replacement. Children are replaced in place;
            if ``module`` itself is an instance of ``replaced_module_type``, a new module is
            returned instead, so always use the return value.
    """

    def default_replace_func(replaced_module_type, new_module_type):
        return new_module_type()

    if replace_func is None:
        replace_func = default_replace_func

    model = module
    if isinstance(module, replaced_module_type):
        model = replace_func(replaced_module_type, new_module_type)
    else:  # recursively replace
        for name, child in module.named_children():
            # forward replace_func so nested matches use the same replace logic
            new_child = replace_module(child, replaced_module_type, new_module_type, replace_func)
            if new_child is not child:  # child has been replaced
                model.add_module(name, new_child)

    return model
```
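The recursion can be exercised without PyTorch. The sketch below is a minimal, self-contained approximation under stated assumptions: `Module` is a hypothetical stand-in that mimics only the `named_children()` and `add_module()` methods of `torch.nn.Module`, and `ReLU`, `Identity` and `Linear` are hypothetical placeholder classes, not the torch layers. The `replace_module` in the sketch forwards `replace_func` into the recursive call, so a custom replace function also reaches nested children.

```python
from typing import Callable, Dict, Iterator, Optional, Tuple, Type


class Module:
    """Hypothetical stand-in for torch.nn.Module, exposing only the two
    methods replace_module uses: named_children() and add_module()."""

    def __init__(self, **children: "Module") -> None:
        self._modules: Dict[str, "Module"] = dict(children)

    def named_children(self) -> Iterator[Tuple[str, "Module"]]:
        # Iterate over a snapshot so add_module() can run inside the loop.
        return iter(list(self._modules.items()))

    def add_module(self, name: str, module: "Module") -> None:
        # Registering an existing name overwrites the old child.
        self._modules[name] = module


# Hypothetical leaf layers used only for this demo.
class ReLU(Module): ...
class Identity(Module): ...
class Linear(Module): ...


ReplaceFunc = Callable[[Type[Module], Type[Module]], Module]


def replace_module(
    module: Module,
    replaced_module_type: Type[Module],
    new_module_type: Type[Module],
    replace_func: Optional[ReplaceFunc] = None,
) -> Module:
    def default_replace_func(old_type: Type[Module], new_type: Type[Module]) -> Module:
        return new_type()

    if replace_func is None:
        replace_func = default_replace_func
    if isinstance(module, replaced_module_type):
        return replace_func(replaced_module_type, new_module_type)
    for name, child in module.named_children():
        # Forward replace_func so nested matches use the same logic.
        new_child = replace_module(child, replaced_module_type, new_module_type, replace_func)
        if new_child is not child:
            module.add_module(name, new_child)
    return module


# Usage: a nested model with ReLUs at two depths.
net = Module(fc=Linear(), act=ReLU(), block=Module(act=ReLU(), fc=Linear()))
net = replace_module(net, ReLU, Identity)
print(type(net._modules["act"]).__name__)                    # Identity
print(type(net._modules["block"]._modules["act"]).__name__)  # Identity

# A custom replace_func is applied to nested children as well.
calls = []


def logging_replace(old_type: Type[Module], new_type: Type[Module]) -> Module:
    calls.append(old_type.__name__)
    return new_type()


replace_module(Module(block=Module(act=ReLU())), ReLU, Identity, logging_replace)
print(calls)  # ['ReLU']
```

With real models the same call pattern would apply to `torch.nn` layers, for example swapping `nn.ReLU` for `nn.Identity`. Note that `replace_func` only receives the two types, not the old instance, so it cannot copy state such as weights from the module it replaces.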