
Function freeze_variables

tensorpack/tfutils/varreplace.py:61–129

Return a context to freeze variables, by wrapping ``tf.get_variable`` with a custom getter. It works by either applying ``tf.stop_gradient`` on the variables, or keeping them out of the ``TRAINABLE_VARIABLES`` collection, or both. Both options have their own pros and cons.

(stop_gradient=True, skip_collection=False)
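The trade-off between the two options can be illustrated with a toy model that does not need TensorFlow. This is only a sketch: `ToyVariable`, `ToyStoppedTensor`, `toy_get_variable`, and `new_collections` are hypothetical stand-ins invented for illustration and are not part of tensorpack or TensorFlow. It shows that stopping the gradient changes what the caller receives while the variable stays registered as trainable, whereas skipping the collection keeps the variable's type and changes where it is recorded.

```python
# Toy model only: plain-Python stand-ins, not TensorFlow objects.
class ToyVariable:
    def __init__(self, name):
        self.name = name


class ToyStoppedTensor:
    """Stands in for the Tensor that tf.stop_gradient would return."""

    def __init__(self, source):
        self.source = source


def toy_get_variable(name, collections, stop_gradient, skip_collection):
    """Create a variable the way a freezing getter might (hypothetical helper)."""
    var = ToyVariable(name)
    # skip_collection decides which collection records the variable.
    key = "MODEL_VARIABLES" if skip_collection else "TRAINABLE_VARIABLES"
    collections[key].append(var)
    # stop_gradient changes what the caller receives, not what is recorded.
    return ToyStoppedTensor(var) if stop_gradient else var


def new_collections():
    return {"TRAINABLE_VARIABLES": [], "MODEL_VARIABLES": []}


# Option 1: stop_gradient only.
cols_a = new_collections()
out_a = toy_get_variable("fc/W", cols_a, stop_gradient=True, skip_collection=False)
print(type(out_a).__name__)                             # ToyStoppedTensor
print([v.name for v in cols_a["TRAINABLE_VARIABLES"]])  # ['fc/W']

# Option 2: skip_collection only.
cols_b = new_collections()
out_b = toy_get_variable("fc/W", cols_b, stop_gradient=False, skip_collection=True)
print(type(out_b).__name__)                             # ToyVariable
print([v.name for v in cols_b["TRAINABLE_VARIABLES"]])  # []
```

In the first case, anything that later looks the variable up through the trainable collection (for example a weight decay term) would still see it, which is the caveat the docstring raises about `stop_gradient`.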


def freeze_variables(stop_gradient=True, skip_collection=False):
    """
    Return a context to freeze variables,
    by wrapping ``tf.get_variable`` with a custom getter.
    It works by either applying ``tf.stop_gradient`` on the variables,
    or keeping them out of the ``TRAINABLE_VARIABLES`` collection, or
    both. Both options have their own pros and cons.

    Example:
        .. code-block:: python

            from tensorpack.tfutils import varreplace
            with varreplace.freeze_variables(stop_gradient=False, skip_collection=True):
                x = FullyConnected('fc', x, 1000)   # fc/* will not be trained

    Args:
        stop_gradient (bool): if True, variables returned from `get_variable`
            will be wrapped with `tf.stop_gradient`.

            Note that the created variables may still have a gradient when accessed
            by other approaches (e.g. by name, or by collection).
            For example, they may still have a gradient in weight decay.
            Also note that this makes `tf.get_variable` return a Tensor instead of a Variable,
            which may break the existing contract.
            Therefore, it's recommended to use the `skip_collection` option instead.
        skip_collection (bool): if True, do not add the variable to the
            ``TRAINABLE_VARIABLES`` collection, but to the ``MODEL_VARIABLES``
            collection. As a result they will not be trained by default.

    Note:

    `stop_gradient` only prevents variables returned by `get_variable` **within the context**
    from contributing gradient in this context. Therefore it may not completely freeze the variables.
    For example:

        1. If a variable is created or reused outside of the context, it can still contribute to the
           gradient of other tensors.
        2. If a frozen variable is accessed by other approaches (e.g., by name, by collection),
           it can still contribute to the gradient of other tensors.
           For example, weight decay cannot be stopped by a `stop_gradient` context.

    `skip_collection` has to be used the first time the variable is created.
    Once `skip_collection` is used, the variable is not a trainable variable anymore,
    and will be completely frozen from gradient updates in tensorpack's single-cost trainer.

    Choose the option carefully depending on what you need.
    """
    def custom_getter(getter, *args, **kwargs):
        trainable = kwargs.get('trainable', True)
        name = args[0] if len(args) else kwargs.get('name')
        if skip_collection:
            kwargs['trainable'] = False
        v = getter(*args, **kwargs)
        # do not perform unnecessary changes if it's not originally trainable
        # otherwise the variable may get added to MODEL_VARIABLES twice
        if trainable and skip_collection:
            if isinstance(v, tf.Variable):
                tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)

Callers 1

backbone_scope · Function · 0.90

Calls 1

custom_getter_scope · Function · 0.85

Tested by

no test coverage detected