Encodes a NumPy ndarray into, or decodes an ndarray from, a Spark DataFrame field.
| 131 | |
| 132 | |
class NdarrayCodec(DataframeColumnCodec):
    """Codec for numpy ndarray fields kept in a binary spark dataframe column.

    Arrays are serialized with ``np.save`` into an in-memory buffer and restored
    with ``np.load``.
    """

    def encode(self, unischema_field, value):
        """Validate ``value`` against ``unischema_field`` and serialize it.

        :param unischema_field: Field descriptor; its ``name``, ``numpy_dtype`` and
            ``shape`` attributes are read here.
        :param value: The array to encode. Must be an ``np.ndarray``.
        :return: A ``bytearray`` holding the output of ``np.save`` for ``value``.
        :raises ValueError: If ``value`` is not an ndarray, if its dtype type differs
            from the field's ``numpy_dtype``, or if ``_is_compliant_shape`` rejects
            its shape for the field's ``shape``.
        """
        expected_dtype = unischema_field.numpy_dtype

        # Anything that is not an ndarray is rejected before dtype/shape are examined.
        if not isinstance(value, np.ndarray):
            raise ValueError('Unexpected type of {} feature. '
                             'Expected ndarray of {}. Got {}'.format(unischema_field.name, expected_dtype, type(value)))

        if expected_dtype != value.dtype.type:
            raise ValueError('Unexpected type of {} feature. '
                             'Expected {}. Got {}'.format(unischema_field.name, expected_dtype, value.dtype))

        expected_shape = unischema_field.shape
        shape_ok = _is_compliant_shape(value.shape, expected_shape)
        if not shape_ok:
            raise ValueError('Unexpected dimensions of {} feature. '
                             'Expected {}. Got {}'.format(unischema_field.name, expected_shape, value.shape))

        with BytesIO() as buffer:
            np.save(buffer, value)
            # Copy the bytes out while the buffer is still open.
            return bytearray(buffer.getvalue())

    def decode(self, unischema_field, value):
        """Restore an array from the bytes produced by :meth:`encode`.

        :param unischema_field: Not used by this codec.
        :param value: Bytes-like object that is wrapped in a ``BytesIO``.
        :return: Whatever ``np.load`` reads from the wrapped buffer.
        """
        return np.load(BytesIO(value))

    def spark_dtype(self):
        """Return the spark sql type of the encoded column (``BinaryType``)."""
        # Lazy loading pyspark to avoid creating pyspark dependency on data reading code path
        # (currently works only with make_batch_reader). We should move all pyspark related code into a separate module
        from pyspark.sql.types import BinaryType

        return BinaryType()

    def __str__(self):
        """Render the codec as its class name followed by empty parentheses.

        For this class the result is the text ``NdarrayCodec()``.
        """
        return f'{type(self).__name__}()'
| 172 | |
| 173 | |
| 174 | class CompressedNdarrayCodec(DataframeColumnCodec): |
no outgoing calls
searching dependent graphs…