MCPcopy
hub / github.com/uber/petastorm / ScalarCodec

Class ScalarCodec

petastorm/codecs.py:215–271  ·  view source on GitHub ↗

Encodes a scalar into a spark dataframe field.

Source from the content-addressed store, hash-verified

213
214
class ScalarCodec(DataframeColumnCodec):
    """Encodes a scalar into a spark dataframe field."""

    def __init__(self, spark_type):
        """Constructs a codec.

        :param spark_type: an instance of a Type object from :mod:`pyspark.sql.types`
        """
        self._spark_type = spark_type

    def encode(self, unischema_field, value):
        """Validate ``value`` as a scalar and coerce it to a plain Python type.

        The coercion is selected by the spark type this codec was constructed with:
        integer types go through ``int``, float/double through ``float``, boolean
        through ``bool``, and strings are returned as ``str``. For any other spark
        type the value is returned unchanged.

        :param unischema_field: the field being encoded; its ``name`` is used in error
            messages and its ``shape`` must be empty.
        :param value: the scalar to encode. A numpy array with ``shape == ()`` is
            accepted as a scalar.
        :return: the coerced value.
        :raises TypeError: if ``value`` has a ``__len__`` and is neither a ``str`` nor
            an unsized numpy array.
        :raises ValueError: if ``unischema_field.shape`` is non-empty, or if the codec
            is for a string type and ``value`` is not a ``str``.
        """
        # Lazy loading pyspark to avoid creating pyspark dependency on data reading code path
        # (currently works only with make_batch_reader). We should move all pyspark related code into a separate module
        import pyspark.sql.types as sql_types

        # We treat ndarrays with shape=() as scalars
        unsized_numpy_array = isinstance(value, np.ndarray) and value.shape == ()
        # Validate the input to be a scalar (or an unsized numpy array)
        if not unsized_numpy_array and hasattr(value, '__len__') and (not isinstance(value, str)):
            raise TypeError('Expected a scalar as a value for field \'{}\'. '
                            'Got a non-numpy type\'{}\''.format(unischema_field.name, type(value)))

        if unischema_field.shape:
            # The message must be interpolated explicitly: exceptions do not apply
            # %-formatting to extra positional arguments the way logging calls do.
            raise ValueError('The shape field of unischema_field \'%s\' must be an empty tuple (i.e. \'()\' '
                             'to indicate a scalar. However, the actual shape is %s'
                             % (unischema_field.name, unischema_field.shape))

        if isinstance(self._spark_type, (sql_types.ByteType, sql_types.ShortType, sql_types.IntegerType,
                                         sql_types.LongType)):
            return int(value)
        if isinstance(self._spark_type, (sql_types.FloatType, sql_types.DoubleType)):
            return float(value)
        if isinstance(self._spark_type, sql_types.BooleanType):
            return bool(value)
        if isinstance(self._spark_type, sql_types.StringType):
            if not isinstance(value, str):
                raise ValueError(
                    'Expected a string value for field {}. Got type {}'.format(unischema_field.name, type(value)))
            return str(value)

        # Any other spark type: pass the value through untouched.
        return value

    def decode(self, unischema_field, value):
        """Convert ``value`` by calling ``unischema_field.numpy_dtype`` on it.

        :param unischema_field: the field being decoded; its ``numpy_dtype`` is used
            as the conversion callable.
        :param value: the stored value to convert.
        :return: the result of ``unischema_field.numpy_dtype(value)``.
        """
        # We are using pyarrow.serialize that does not support Decimal field types.
        # Tensorflow does not support Decimal types either. We convert all decimals to
        # strings, hence preventing Decimals from getting anywhere into the reader. We may
        # choose to resurrect Decimals support in the future.
        return unischema_field.numpy_dtype(value)

    def spark_dtype(self):
        """Return the spark type object this codec was constructed with."""
        return self._spark_type

    def __str__(self):
        """Represent this codec as ``<codec class name>(<spark type class name>())``.

        For example, a ``ScalarCodec`` holding an ``IntegerType`` instance is rendered
        as ``ScalarCodec(IntegerType())``.
        """
        return f'{type(self).__name__}({type(self._spark_type).__name__}())'
272

Calls

no outgoing calls

Used in the wild — real call sites across dependent graphs

searching dependent graphs…