| 103 | |
| 104 | |
def test_module_non_share_scope():
    """Check the output/param structure when wrapping with `share_scope=False`.

    Every `nn.Dense` met by the interceptor is swapped for a
    `WrapperModule(wrapped=..., share_scope=False)`. The test then compares
    only the tree structure (not the values) of the model output and of the
    initialized params against the expected layouts.
    """
    intercepted_features = []

    def _wrap_dense(module):
        # Anything that is not a `nn.Dense` goes through untouched.
        if not isinstance(module, nn.Dense):
            return module
        # Record which dense layers were replaced, in interception order.
        intercepted_features.append(module.features)
        return WrapperModule(wrapped=module, share_scope=False)

    def _structure_only(tree):
        # Replace every leaf by `None`: only the structure matters here.
        return jax.tree.map(lambda _: None, tree)

    model = MyModule()
    with peft.ModuleInterceptor(_wrap_dense):
        rng = jax.random.key(0)
        dummy_input = jnp.zeros((3,))
        out, params = model.init_with_output(rng, dummy_input)

    out = _structure_only(out)
    params = _structure_only(params)

    expected_out = {
        'y0': None,
        'y1': {'wrapped_1': None},
        'y2': {'wrapped_1': None},
        'y3': {'wrapped_2': None},
    }
    # TODO(epot): Is it possible to have the `Dense_0` to be nested inside the
    # `WrapperModule_0` (By changing the scope or copying the module) ? Is it
    # desirable ?
    # Both dense layers are expected to expose the same param layout: their own
    # `bias` / `kernel` plus the wrapper's `extra_param` nested below them.
    expected_params = {
        'params': {
            dense_name: {
                'WrapperModule_0': {'extra_param': None},
                'bias': None,
                'kernel': None,
            }
            for dense_name in ('Dense_0', 'Dense_1')
        },
    }

    assert intercepted_features == [1, 2]
    assert out == expected_out
    assert params == expected_params
| 147 | |
| 148 | |
| 149 | # TODO(epot): Test a nested replace (module replaced also has sub-modules which |