Encodes a scalar into a spark dataframe field.
| 213 | |
| 214 | |
class ScalarCodec(DataframeColumnCodec):
    """Encodes a scalar into a spark dataframe field."""

    def __init__(self, spark_type):
        """Constructs a codec.

        :param spark_type: an instance of a Type object from :mod:`pyspark.sql.types`
        """
        self._spark_type = spark_type

    def encode(self, unischema_field, value):
        """Converts ``value`` into a plain Python scalar matching this codec's spark type.

        The conversion is chosen by the spark type passed to the constructor:
        ``ByteType``/``ShortType``/``IntegerType``/``LongType`` -> ``int``,
        ``FloatType``/``DoubleType`` -> ``float``, ``BooleanType`` -> ``bool`` and
        ``StringType`` -> ``str``. For any other spark type ``value`` is returned unchanged.

        :param unischema_field: the field being encoded. Its ``name`` is used in error messages and
            its ``shape`` must be empty (e.g. ``()``) to denote a scalar.
        :param value: a scalar, a ``str``, or a numpy array with ``shape == ()``.
        :return: the converted value.
        :raises TypeError: if ``value`` is sized (defines ``__len__``) and is neither a ``str`` nor a
            numpy array with ``shape == ()``. This check runs before the shape check.
        :raises ValueError: if ``unischema_field.shape`` is non-empty, or if the spark type is
            ``StringType`` and ``value`` is not a ``str``.
        """
        # Lazy loading pyspark to avoid creating pyspark dependency on data reading code path
        # (currently works only with make_batch_reader). We should move all pyspark related code into a separate module
        import pyspark.sql.types as sql_types

        # We treat ndarrays with shape=() as scalars
        unsized_numpy_array = isinstance(value, np.ndarray) and value.shape == ()
        # Validate the input to be a scalar (or an unsized numpy array). Strings have __len__ but
        # are accepted as scalars.
        if not unsized_numpy_array and hasattr(value, '__len__') and (not isinstance(value, str)):
            raise TypeError('Expected a scalar as a value for field \'{}\'. '
                            'Got a non-numpy type \'{}\''.format(unischema_field.name, type(value)))

        if unischema_field.shape:
            # Format the message eagerly: exception constructors do not interpolate logging-style
            # positional arguments, so passing them separately would leave '{}' placeholders in
            # str(exc).
            raise ValueError('The shape field of unischema_field \'{}\' must be an empty tuple (i.e. \'()\') '
                             'to indicate a scalar. However, the actual shape is {}'.format(
                                 unischema_field.name, unischema_field.shape))

        if isinstance(self._spark_type, (sql_types.ByteType, sql_types.ShortType, sql_types.IntegerType,
                                         sql_types.LongType)):
            # NOTE: int() truncates non-integral floats toward zero; no range check against the
            # width of the spark integral type is performed here.
            return int(value)
        if isinstance(self._spark_type, (sql_types.FloatType, sql_types.DoubleType)):
            return float(value)
        if isinstance(self._spark_type, sql_types.BooleanType):
            # Python truthiness is used, so e.g. any non-empty string converts to True.
            return bool(value)
        if isinstance(self._spark_type, sql_types.StringType):
            if not isinstance(value, str):
                raise ValueError(
                    'Expected a string value for field {}. Got type {}'.format(unischema_field.name, type(value)))
            # str() normalizes str subclasses (e.g. numpy.str_) to a plain str.
            return str(value)

        return value

    def decode(self, unischema_field, value):
        """Converts a stored value back by calling the field's ``numpy_dtype`` on it.

        :param unischema_field: the field being decoded; its ``numpy_dtype`` is used as a constructor.
        :param value: the stored scalar value.
        :return: ``unischema_field.numpy_dtype(value)``.
        """
        # We are using pyarrow.serialize that does not support Decimal field types.
        # Tensorflow does not support Decimal types neither. We convert all decimals to
        # strings hence prevent Decimals from getting into anywhere in the reader. We may
        # choose to resurrect Decimals support in the future.
        # NOTE(review): no Decimal-specific handling is visible in this method; the conversion
        # referred to above presumably happens elsewhere — confirm.
        return unischema_field.numpy_dtype(value)

    def spark_dtype(self):
        """Returns the spark type instance this codec was constructed with."""
        return self._spark_type

    def __str__(self):
        """Represents the codec by its class name and spark type name.

        For example, a codec built with ``IntegerType()`` is rendered as
        ``ScalarCodec(IntegerType())``.
        """
        return f'{type(self).__name__}({type(self._spark_type).__name__}())'
| 272 |
no outgoing calls
searching dependent graphs…