Return a context to freeze variables, by wrapping ``tf.get_variable`` with a custom getter. It works by either applying ``tf.stop_gradient`` on the variables, or keeping them out of the ``TRAINABLE_VARIABLES`` collection, or both. Both options have their own pros and cons.
(stop_gradient=True, skip_collection=False)
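The `skip_collection` mechanism described above can be sketched without TensorFlow. What follows is a toy model, not tensorpack's or TensorFlow's implementation: `ToyGraph`, `make_freeze_getter`, and the two collection names are hypothetical stand-ins for a TF1 graph, the custom getter, and the `TRAINABLE_VARIABLES` / `MODEL_VARIABLES` collections. It only illustrates how a custom getter can force `trainable=False` at creation time, and why a variable created before entering the context is not frozen.

```python
import functools
from collections import defaultdict
from contextlib import contextmanager

# Hypothetical collection names standing in for the TF1 GraphKeys.
TRAINABLE = "trainable_variables"
MODEL = "model_variables"


class ToyGraph:
    """Toy stand-in for a TF1 graph: a variable store plus named collections."""

    def __init__(self):
        self.variables = {}
        self.collections = defaultdict(list)
        self.getters = []  # stack of active custom getters

    def _base_getter(self, name, trainable=True):
        # Create the variable on first use; afterwards just return it.
        # Collections are only touched at creation time.
        if name not in self.variables:
            self.variables[name] = {"name": name, "value": 0.0}
            if trainable:
                self.collections[TRAINABLE].append(name)
        return self.variables[name]

    def get_variable(self, name, **kwargs):
        # Chain the active custom getters around the base getter.
        getter = self._base_getter
        for custom in self.getters:
            getter = functools.partial(custom, getter)
        return getter(name, **kwargs)

    @contextmanager
    def custom_getter_scope(self, custom_getter):
        self.getters.append(custom_getter)
        try:
            yield
        finally:
            self.getters.pop()


def make_freeze_getter(graph):
    """Build a getter that mimics skip_collection=True (toy version)."""

    def custom_getter(getter, *args, **kwargs):
        trainable = kwargs.get("trainable", True)
        kwargs["trainable"] = False  # keep it out of TRAINABLE on creation
        v = getter(*args, **kwargs)
        # Only re-register variables that were meant to be trainable.
        if trainable and v["name"] not in graph.collections[MODEL]:
            graph.collections[MODEL].append(v["name"])
        return v

    return custom_getter


graph = ToyGraph()
graph.get_variable("out/W")          # created outside the context: trainable
with graph.custom_getter_scope(make_freeze_getter(graph)):
    graph.get_variable("fc/W")       # created inside: kept out of TRAINABLE
    graph.get_variable("out/W")      # reused inside: too late to freeze

print(graph.collections[TRAINABLE])  # ['out/W']
print(graph.collections[MODEL])      # ['fc/W', 'out/W']
```

In this toy, `out/W` stays in the trainable collection even though it was fetched inside the context, because collection membership is decided when the variable is first created.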
def freeze_variables(stop_gradient=True, skip_collection=False):
    """
    Return a context to freeze variables,
    by wrapping ``tf.get_variable`` with a custom getter.
    It works by either applying ``tf.stop_gradient`` on the variables,
    or keeping them out of the ``TRAINABLE_VARIABLES`` collection, or
    both. Both options have their own pros and cons.

    Example:
        .. code-block:: python

            from tensorpack.tfutils import varreplace
            with varreplace.freeze_variables(stop_gradient=False, skip_collection=True):
                x = FullyConnected('fc', x, 1000)  # fc/* will not be trained

    Args:
        stop_gradient (bool): if True, variables returned from `get_variable`
            will be wrapped with `tf.stop_gradient`.

            Note that the created variables may still have a gradient when accessed
            by other approaches (e.g. by name, or by collection).
            For example, they may still have a gradient in weight decay.
            Also note that this makes `tf.get_variable` return a Tensor instead of a Variable,
            which may break code that expects a Variable.
            Therefore, it's recommended to use the `skip_collection` option instead.
        skip_collection (bool): if True, do not add the variable to the
            ``TRAINABLE_VARIABLES`` collection, but to the ``MODEL_VARIABLES``
            collection. As a result they will not be trained by default.

    Note:

        `stop_gradient` only prevents variables returned by `get_variable` **within the context**
        from contributing gradients in this context. Therefore it may not completely freeze the variables.
        For example:

        1. If a variable is created or reused outside of the context, it can still contribute to the
           gradient of other tensors.
        2. If a frozen variable is accessed by other approaches (e.g., by name, by collection),
           it can still contribute to the gradient of other tensors.
           For example, weight decay cannot be stopped by a `stop_gradient` context.

        `skip_collection` has to be used the first time the variable is created.
        Once `skip_collection` is used, the variable is not a trainable variable anymore,
        and will be completely frozen from gradient updates in tensorpack's single-cost trainer.

        Choose the option carefully depending on what you need.
    """
    def custom_getter(getter, *args, **kwargs):
        trainable = kwargs.get('trainable', True)
        name = args[0] if len(args) else kwargs.get('name')
        if skip_collection:
            kwargs['trainable'] = False
        v = getter(*args, **kwargs)
        # do not perform unnecessary changes if it's not originally trainable
        # otherwise the variable may get added to MODEL_VARIABLES twice
        if trainable and skip_collection:
            if isinstance(v, tf.Variable):
                tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)