MCPcopy
hub / github.com/google-deepmind/gemma / test_module

Function test_module

gemma/peft/_interceptors_test.py:64–102  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

62
63
def test_module():
  """Checks that a ModuleInterceptor callback can wrap the nn.Dense layers of MyModule.

  The callback records the `features` of every nn.Dense it is handed and
  returns a WrapperModule around it; any other module is returned unchanged.
  After initializing MyModule under the interceptor, the test compares the
  recorded features and the tree structure (all leaves replaced by None) of
  both the outputs and the params against the expected layout.
  """
  replaced_features = []

  def _replace_module(module):
    # Guard clause: anything that is not a Dense layer passes through as-is.
    if not isinstance(module, nn.Dense):
      return module
    replaced_features.append(module.features)
    return WrapperModule(wrapped=module)

  # The model is constructed before the interceptor context is entered;
  # only the init call runs under the interceptor.
  model = MyModule()
  with peft.ModuleInterceptor(_replace_module):
    out, params = model.init_with_output(
        jax.random.key(0), jnp.zeros((3,))
    )

  def _structure(tree):
    # Only the tree layout is under test, so every leaf is blanked to None.
    return jax.tree.map(lambda _: None, tree)

  out_structure = _structure(out)
  params_structure = _structure(params)

  # NOTE(review): 'y1' and 'y2' both expect 'wrapped_1' while only two Dense
  # features are recorded -- presumably one Dense is reused; confirm in MyModule.
  expected_out = {
      'y0': None,
      'y1': {'wrapped_1': None},
      'y2': {'wrapped_1': None},
      'y3': {'wrapped_2': None},
  }
  dense_leaves = {
      'extra_param': None,
      'bias': None,
      'kernel': None,
  }
  expected_params = {
      'params': {
          'Dense_0': dict(dense_leaves),
          'Dense_1': dict(dense_leaves),
      },
  }

  assert replaced_features == [1, 2]
  assert out_structure == expected_out
  assert params_structure == expected_params
103
104
105def test_module_non_share_scope():

Callers

nothing calls this directly

Calls 2

MyModuleClass · 0.70
mapMethod · 0.45

Tested by

no test coverage detected