Set trainable flag for all :class:`tf.Variable`\ s and :class:`gpflow.Parameter`\ s in a :class:`tf.Module` or collection of :class:`tf.Module`\ s.
set_trainable(model: Union[tf.Module, Iterable[tf.Module]], flag: bool) -> None
```python
from typing import Iterable, Union

import tensorflow as tf


def set_trainable(model: Union[tf.Module, Iterable[tf.Module]], flag: bool) -> None:
    r"""
    Set trainable flag for all :class:`tf.Variable`\ s and :class:`gpflow.Parameter`\ s in a
    :class:`tf.Module` or collection of :class:`tf.Module`\ s.
    """
    # Accept either a single module or any iterable of modules.
    modules = [model] if isinstance(model, tf.Module) else model

    for mod in modules:
        # `mod.variables` also collects the variables of nested submodules.
        for variable in mod.variables:
            # `tf.Variable.trainable` is a read-only property, so the private
            # backing attribute is assigned directly.
            variable._trainable = flag
```
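Here is a minimal usage sketch. It assumes TensorFlow 2 is installed. It redeclares `set_trainable` locally with the same body as the listing, so it runs without GPflow. `Linear` is a hypothetical toy module used only for illustration. The sketch exercises both accepted input shapes: a single `tf.Module` and a list of modules.

```python
from typing import Iterable, Union

import tensorflow as tf


def set_trainable(model: Union[tf.Module, Iterable[tf.Module]], flag: bool) -> None:
    """Local copy of the function above so this sketch runs without GPflow."""
    modules = [model] if isinstance(model, tf.Module) else model
    for mod in modules:
        for variable in mod.variables:
            # Relies on the private `_trainable` attribute of tf.Variable.
            variable._trainable = flag


class Linear(tf.Module):
    """Hypothetical toy module holding two variables."""

    def __init__(self) -> None:
        super().__init__()
        self.w = tf.Variable(1.0, name="w")
        self.b = tf.Variable(0.0, name="b")


# A single module: freeze every variable it owns.
single = Linear()
set_trainable(single, False)
print(len(single.trainable_variables))  # 0

# A collection of modules: freeze both, then unfreeze only the second.
pair = [Linear(), Linear()]
set_trainable(pair, False)
set_trainable(pair[1], True)
print([len(m.trainable_variables) for m in pair])  # [0, 2]
```

The flag is written to the private `_trainable` attribute rather than through a public setter. The change therefore shows up in `tf.Module.trainable_variables`, which is typically the list handed to an optimizer.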