(self, unischema_field, value)
| 223 | self._spark_type = spark_type |
| 224 | |
def encode(self, unischema_field, value):
    """Convert a scalar ``value`` to the plain Python type matching this codec's Spark type.

    :param unischema_field: Field descriptor. Its ``name`` is used in error messages and its ``shape``
      must be empty (falsy), since only scalars are handled here.
    :param value: A scalar, a string, or a numpy array with ``shape == ()``.
    :return: ``int(value)`` for byte/short/integer/long Spark types, ``float(value)`` for float/double,
      ``bool(value)`` for boolean, ``str(value)`` for string; ``value`` unchanged for any other Spark type.
    :raises TypeError: If ``value`` has a ``__len__`` attribute and is neither a string nor an unsized
      numpy array.
    :raises ValueError: If ``unischema_field.shape`` is non-empty, or if the Spark type is a string type
      and ``value`` is not a ``str``.
    """
    # We treat ndarrays with shape=() as scalars
    unsized_numpy_array = isinstance(value, np.ndarray) and value.shape == ()
    # Validate the input to be a scalar (or an unsized numpy array)
    if not unsized_numpy_array and hasattr(value, '__len__') and (not isinstance(value, str)):
        raise TypeError('Expected a scalar as a value for field \'{}\'. '
                        'Got a non-numpy type\'{}\''.format(unischema_field.name, type(value)))

    if unischema_field.shape:
        # The message must be interpolated explicitly: passing the values as extra arguments to
        # ValueError would leave the '%s' placeholders unformatted.
        raise ValueError('The shape field of unischema_field \'%s\' must be an empty tuple (i.e. \'()\' '
                         'to indicate a scalar. However, the actual shape is %s'
                         % (unischema_field.name, unischema_field.shape))

    # Lazy loading pyspark to avoid creating pyspark dependency on data reading code path
    # (currently works only with make_batch_reader). We should move all pyspark related code into a separate module.
    # The import is deferred until after input validation so that invalid input is reported
    # with the errors above even when pyspark is not installed.
    import pyspark.sql.types as sql_types

    integral_types = (sql_types.ByteType, sql_types.ShortType, sql_types.IntegerType, sql_types.LongType)
    if isinstance(self._spark_type, integral_types):
        return int(value)
    if isinstance(self._spark_type, (sql_types.FloatType, sql_types.DoubleType)):
        return float(value)
    if isinstance(self._spark_type, sql_types.BooleanType):
        return bool(value)
    if isinstance(self._spark_type, sql_types.StringType):
        if not isinstance(value, str):
            raise ValueError(
                'Expected a string value for field {}. Got type {}'.format(unischema_field.name, type(value)))
        return str(value)

    # Any other Spark type is passed through untouched.
    return value
| 255 | |
| 256 | def decode(self, unischema_field, value): |
| 257 | # We are using pyarrow.serialize that does not support Decimal field types. |
no outgoing calls