Sets default float type. Available options are `np.float16`, `np.float32`, or `np.float64`.
(value_type: type)
| 283 | |
| 284 | |
def set_default_float(value_type: type) -> None:
    """
    Sets default float type. Available options are `np.float16`, `np.float32`,
    or `np.float64`.

    :param value_type: The dtype to use as the default float type. It may be
        anything ``tf.as_dtype`` accepts (e.g. a NumPy float type or the
        corresponding TensorFlow dtype), provided it denotes a floating-point
        type. The config always receives the NumPy equivalent of this dtype.
    :raises TypeError: If ``value_type`` cannot be converted to a TensorFlow
        dtype, or if it converts to a non-floating-point dtype.
    """
    try:
        # Normalise the input; this also verifies it is a valid tf/np dtype.
        tf_dtype = tf.as_dtype(value_type)
    except TypeError as err:
        # Chain explicitly so the original conversion error is reported as
        # the direct cause rather than as an incidental handling context.
        raise TypeError(f"{value_type} is not a valid tf or np dtype") from err

    if not tf_dtype.is_floating:
        raise TypeError(f"{value_type} is not a float dtype")

    # Store the NumPy equivalent so the config holds the same kind of value
    # whether a tf or an np dtype was passed in.
    set_config(replace(config(), float=tf_dtype.as_numpy_dtype))
| 299 | |
| 300 | |
| 301 | def set_default_jitter(value: float) -> None: |
searching dependent graphs…