| 62 | |
| 63 | |
def test_module():
  """Checks that `peft.ModuleInterceptor` wraps every `nn.Dense` it sees.

  The interceptor callback records the `features` of each intercepted
  `nn.Dense` and returns a `WrapperModule` around it; any other module is
  returned unchanged. The test then compares the recorded features and the
  pytree *structure* (all leaves replaced by `None`) of the outputs and params
  produced by `MyModule.init_with_output`.
  """
  seen_dense_features = []

  def _replace_module(module):
    # Non-Dense modules pass through untouched.
    if not isinstance(module, nn.Dense):
      return module
    seen_dense_features.append(module.features)
    return WrapperModule(wrapped=module)

  def _structure_only(tree):
    # Drop the array values so only the pytree layout is compared.
    return jax.tree.map(lambda _: None, tree)

  model = MyModule()
  with peft.ModuleInterceptor(_replace_module):
    raw_out, raw_params = model.init_with_output(
        jax.random.key(0), jnp.zeros((3,))
    )

  out = _structure_only(raw_out)
  params = _structure_only(raw_params)

  expected_out = {
      'y0': None,
      'y1': {'wrapped_1': None},
      'y2': {'wrapped_1': None},
      'y3': {'wrapped_2': None},
  }
  # Both Dense layers carry the wrapper's extra parameter next to their own.
  wrapped_dense_params = {
      'extra_param': None,
      'bias': None,
      'kernel': None,
  }
  expected_params = {
      'params': {
          'Dense_0': wrapped_dense_params,
          'Dense_1': wrapped_dense_params,
      },
  }

  assert seen_dense_features == [1, 2]
  assert out == expected_out
  assert params == expected_params
| 104 | |
| 105 | def test_module_non_share_scope(): |