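Convert a PyTensor tensor or NumPy array to the `pytensor.config.floatX` dtype (inputs without an `astype` method, such as Python scalars, are converted with `np.asarray`).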
(X)
| 262 | |
| 263 | |
def floatX(X):
    """Cast ``X`` to the dtype named by ``pytensor.config.floatX``.

    Objects that provide an ``astype`` method (for example PyTensor
    variables and NumPy arrays) are converted through that method. If the
    attempt raises ``AttributeError`` (e.g. ``X`` is a plain Python scalar
    or list with no ``astype``), ``X`` is instead wrapped with
    ``np.asarray`` using the same target dtype.

    Parameters
    ----------
    X : object
        Value to convert.

    Returns
    -------
    object
        Whatever ``X.astype(...)`` returns, or a NumPy array built by
        ``np.asarray`` in the fallback path.
    """
    try:
        converted = X.astype(pytensor.config.floatX)
    except AttributeError:
        # No usable ``astype``: let NumPy build an array of the target dtype.
        converted = np.asarray(X, dtype=pytensor.config.floatX)
    return converted
| 271 | |
| 272 | |
| 273 | _conversion_map = {"float64": "int32", "float32": "int16", "float16": "int8", "float8": "int8"} |
no outgoing calls