(x: TensorData)
| 45 | "return: [any...]", |
| 46 | ) |
| 47 | def to_default_float(x: TensorData) -> tf.Tensor: |
| 48 | if not tf.is_tensor(x): |
| 49 | # workaround for the fact that tf.cast(python_float, dtype=tf.float64) doesn't directly convert |
| 50 | # python floats to tf.float64 tensors. Instead, it converts the python float to a |
| 51 | # tf.float32 tensor, and then casts that to be tf.float64. This results in a loss |
| 52 | # of precision. See https://github.com/tensorflow/tensorflow/issues/57779 for more context. |
| 53 | return tf.convert_to_tensor(x, default_float()) |
| 54 | return tf.cast(x, dtype=default_float()) |
| 55 | |
| 56 | |
| 57 | def set_trainable(model: Union[tf.Module, Iterable[tf.Module]], flag: bool) -> None: |
searching dependent graphs…