Repository: github.com/tensorpack/tensorpack

Method post_process_model

tensorpack/contrib/keras.py:56–76
(model)


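```python
import tensorflow as tf  # TF1 graph-mode API (tf.GraphKeys, tf.trainable_variables)

from tensorpack.tfutils.collection import backup_collection, restore_collection
from tensorpack.utils import logger

# In the source these are locals of the enclosing method, defined before line 56.
# Only update_ops_backup appears in this listing; the first two are assumed
# reconstructions of names the function body references.
old_trainable_names = {x.name for x in tf.trainable_variables()}
trainable_backup = backup_collection([tf.GraphKeys.TRAINABLE_VARIABLES])
update_ops_backup = backup_collection([tf.GraphKeys.UPDATE_OPS])


def post_process_model(model):
    # Names registered as trainable while the Keras model was being built.
    added_trainable_names = {x.name for x in tf.trainable_variables()}
    # Reset TRAINABLE_VARIABLES to its state before the model was built.
    restore_collection(trainable_backup)

    for v in model.weights:
        # In Keras, the collection is not respected and could contain non-trainable vars.
        # We put M.weights into the collection instead.
        if v.name not in old_trainable_names and v.name in added_trainable_names:
            tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, v)
    new_trainable_names = {x.name for x in tf.trainable_variables()}

    for n in added_trainable_names:
        if n not in new_trainable_names:
            logger.warn("Keras created trainable variable '{}' which is actually not trainable. "
                        "This was automatically corrected.".format(n))

    # Keras models might not use this collection at all (in some versions).
    # This is a BC-breaking change of tf.keras: https://github.com/tensorflow/tensorflow/issues/19643
    restore_collection(update_ops_backup)
    for op in model.updates:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, op)

# The enclosing method continues after line 76:
#     if self.cached_model is None:
#         assert not reuse
```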
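The reconciliation in post_process_model is hard to follow because it depends on TF1 graph state. As a rough illustration only, the sketch below replays the same set operations in plain Python with no TensorFlow: graph collections are modeled as a dict of name lists, and variables and ops as plain strings (so `v` stands in for `v.name`). `FakeModel`, the `TRAINABLE`/`UPDATE_OPS` keys, and the simplified `backup_collection`/`restore_collection` helpers are hypothetical stand-ins, not tensorpack's or TensorFlow's APIs. One deliberate difference: it returns the dropped names instead of logging a warning.

```python
from dataclasses import dataclass
from typing import Dict, List, Set

# Hypothetical stand-ins for tf.GraphKeys.TRAINABLE_VARIABLES / UPDATE_OPS.
TRAINABLE = "trainable_variables"
UPDATE_OPS = "update_ops"

Graph = Dict[str, List[str]]  # collection key -> list of variable/op names


@dataclass
class FakeModel:
    """Hypothetical Keras-model stand-in; only .weights and .updates are read."""
    weights: List[str]  # variable names (the original compares v.name)
    updates: List[str]  # update-op names


def backup_collection(graph: Graph, keys: List[str]) -> Graph:
    # Copy each list so later appends to the graph do not leak into the backup.
    return {k: list(graph.get(k, [])) for k in keys}


def restore_collection(graph: Graph, backup: Graph) -> None:
    # Overwrite each backed-up collection with a copy of its saved contents.
    for k, items in backup.items():
        graph[k] = list(items)


def post_process_model(graph: Graph, model: FakeModel, old_names: Set[str],
                       trainable_backup: Graph, update_ops_backup: Graph) -> List[str]:
    # Names registered as trainable after the model was built (old + new).
    added = set(graph[TRAINABLE])
    restore_collection(graph, trainable_backup)
    for v in model.weights:
        # Re-add only model weights that are new AND were registered as trainable.
        if v not in old_names and v in added:
            graph[TRAINABLE].append(v)
    new = set(graph[TRAINABLE])
    # Registered as trainable during the build but not part of model.weights.
    # Sorted only to make the result deterministic.
    dropped = sorted(added - new)
    # Discard whatever the build did to UPDATE_OPS and use model.updates instead.
    restore_collection(graph, update_ops_backup)
    graph[UPDATE_OPS].extend(model.updates)
    return dropped


# Usage: collection state before the model is built.
graph: Graph = {TRAINABLE: ["pre/w:0"], UPDATE_OPS: ["pre/ema_update"]}
old_names = set(graph[TRAINABLE])
trainable_backup = backup_collection(graph, [TRAINABLE])
update_ops_backup = backup_collection(graph, [UPDATE_OPS])

# Simulated build: a stray variable lands in the trainable collection,
# and a stray op lands in UPDATE_OPS.
graph[TRAINABLE] += ["dense/kernel:0", "stray/counter:0"]
graph[UPDATE_OPS].append("stray/op")
model = FakeModel(weights=["dense/kernel:0", "bn/moving_mean:0"],
                  updates=["bn/update_mean"])

dropped = post_process_model(graph, model, old_names,
                             trainable_backup, update_ops_backup)
print(graph[TRAINABLE])   # ['pre/w:0', 'dense/kernel:0']
print(graph[UPDATE_OPS])  # ['pre/ema_update', 'bn/update_mean']
print(dropped)            # ['stray/counter:0']
```

In this run, `stray/counter:0` is removed from the trainable collection (the case the original logs a warning for), while `bn/moving_mean:0` is not added because it was never registered as trainable during the build. The stray update op is discarded in favor of `model.updates`.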

Callers

nothing calls this directly

Calls 2

restore_collection · function · 0.85
format · method · 0.80

Tested by

no test coverage detected