Args: custom_getter: the same as in :func:`tf.get_variable`. Returns: the current variable scope with a custom_getter.
(custom_getter)
| 13 | |
@contextmanager
def custom_getter_scope(custom_getter):
    """
    Re-enter the current variable scope with ``custom_getter`` attached.

    Args:
        custom_getter: the same as in :func:`tf.get_variable`

    Yields:
        Nothing. While the ``with`` body runs, the current variable scope
        is re-entered with ``custom_getter`` passed to
        ``tf.variable_scope``.
    """
    current_scope = tf.get_variable_scope()

    if get_tf_version_tuple() < (1, 5):
        # Legacy path: ``auxiliary_name_scope`` is only passed on TF >= 1.5,
        # so here the name scope captured *before* entering the variable
        # scope is re-entered by hand.
        outer_ns = tf.get_default_graph().get_name_scope()
        # NOTE(review): the trailing '/' presumably makes TF re-enter the
        # captured name scope verbatim instead of uniquifying it -- confirm.
        reentered_ns = outer_ns + '/' if outer_ns else ''
        with tf.variable_scope(current_scope, custom_getter=custom_getter), \
                tf.name_scope(reentered_ns):
            yield
        return

    with tf.variable_scope(
            current_scope,
            custom_getter=custom_getter,
            auxiliary_name_scope=False):
        yield
| 35 | |
| 36 | |
| 37 | def remap_variables(fn): |
no test coverage detected