(model)
| 54 | update_ops_backup = backup_collection([tf.GraphKeys.UPDATE_OPS]) |
| 55 | |
def post_process_model(model):
    """Re-sync the TRAINABLE_VARIABLES and UPDATE_OPS collections after building *model*.

    Keras does not reliably respect TF graph collections. Building a model may
    register variables as trainable that are not, and may not register its
    update ops at all. This function:

    1. Snapshots the names currently in TRAINABLE_VARIABLES, then restores that
       collection from ``trainable_backup``.
    2. Re-adds each ``model.weights`` variable that was newly registered as
       trainable, i.e. not in ``old_trainable_names`` but present in the snapshot.
    3. Warns about every newly registered name that step 2 did not re-add.
    4. Restores UPDATE_OPS from ``update_ops_backup`` and adds every op in
       ``model.updates``.

    ``trainable_backup``, ``old_trainable_names`` and ``update_ops_backup`` come
    from the enclosing scope. NOTE(review): presumably they are captured before
    the model is built; confirm against the enclosing method.

    Args:
        model: a Keras model. Only its ``weights`` and ``updates`` attributes are read.
    """
    # Every trainable variable name registered right now, including any that
    # Keras just added while building the model.
    added_trainable_names = {x.name for x in tf.trainable_variables()}
    restore_collection(trainable_backup)

    for v in model.weights:
        # In Keras, the collection is not respected and could contain non-trainable vars.
        # We put model.weights into the collection instead.
        if v.name not in old_trainable_names and v.name in added_trainable_names:
            tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, v)
    new_trainable_names = {x.name for x in tf.trainable_variables()}

    # Names Keras registered as trainable that did not survive the correction above.
    # Sorted so the warnings come out in the same order on every run
    # (iteration order of a set of strings depends on hash randomization).
    for n in sorted(added_trainable_names - new_trainable_names):
        # `warning`, not the deprecated `warn` alias (removed from stdlib logging in 3.13).
        logger.warning("Keras created trainable variable '{}' which is actually not trainable. "
                       "This was automatically corrected.".format(n))

    # Keras models might not use this collection at all (in some versions).
    # This is a BC-breaking change of tf.keras: https://github.com/tensorflow/tensorflow/issues/19643
    restore_collection(update_ops_backup)
    for op in model.updates:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, op)
| 77 | |
| 78 | if self.cached_model is None: |
| 79 | assert not reuse |
nothing calls this directly
no test coverage detected