```python
class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable,
                     KerasEstimatorParamsWritable):
    """Spark Estimator for fitting Keras models to a DataFrame.

    Supports standalone `keras` and `tf.keras`, and TensorFlow 1.X and 2.X.

    Args:
        num_proc: Number of Horovod processes. Defaults to `spark.default.parallelism`.
        data_module: (Optional) DataModule class used for training and validation. If not set, defaults to
            the PetastormDataModule.
        model: Keras model to train.
        backend: Optional Backend object for running the distributed training function. Defaults to
            SparkBackend with `num_proc` worker processes. Cannot be specified if `num_proc` is also provided.
        store: Store object that abstracts reading and writing of intermediate data and run results.
        custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be
            considered during serialization/deserialization.
        optimizer: Keras optimizer to be converted into a `hvd.DistributedOptimizer` for training.
        loss: Keras loss or list of losses.
        loss_weights: Optional list of float weight values to assign to each loss.
        sample_weight_col: Optional column indicating the weight of each sample.
        gradient_compression: Gradient compression used by `hvd.DistributedOptimizer`.
        metrics: Optional metrics to record.
        feature_cols: Column names used as feature inputs to the model. Must be a list with each feature
            mapping, in order, to one of the model's inputs.
        label_cols: Column names used as labels. Must be a list with one label for each output of the model.
        validation: Optional validation column name (string) where every row in the column is either 1/True
            or 0/False, or validation split (float) giving the fraction of data to be randomly selected for
            validation.
        callbacks: Keras callbacks.
        batch_size: Number of rows from the DataFrame per batch.
        val_batch_size: Number of rows from the DataFrame per batch for validation. Defaults to `batch_size`
            if not set.
        epochs: Number of epochs to train.
        verbose: Verbosity level [0, 2] (default: 1).
        random_seed: Optional random seed to use for TensorFlow. Default: None.
        shuffle_buffer_size: (Deprecated) Optional size of the in-memory shuffle buffer in rows (on training
            data). Allocating a larger buffer increases the randomness of shuffling at the cost of more host
            memory. By default, the size is estimated assuming 4GB of memory per host. Setting
            shuffle_buffer_size=0 turns off shuffling.
        shuffle: (Optional) Whether to shuffle training samples. Defaults to True.
        partitions_per_process: Number of Parquet partitions to assign per worker process from `num_proc`
            (default: 10).
        run_id: Optional unique ID for this run for organization in the Store. Automatically assigned if not
            provided.
        train_steps_per_epoch: Number of steps to train each epoch. Useful for testing that the model trains
            successfully. Defaults to training the entire dataset each epoch.
        validation_steps_per_epoch: Number of validation steps to perform each epoch.
        transformation_fn: Optional function that takes a row as its parameter and returns a modified row
            that is then fed into the train or validation step. This transformation is applied after
            batching. See Petastorm [TransformSpec](https://github.com/uber/petastorm/blob/master/petastorm/transform.py)
            for more details. Note that this function constructs another function, which should perform the
            transformation.
        train_reader_num_workers: Number of parallel processes that read the training data from the data
            store and apply data transformations to it. Increasing this number generally increases the
            reading rate but also increases the memory footprint. More processes are particularly useful
            if the bandwidth to the data store is not high enough, or if users need to apply transformations
            such as decompression or data augmentation on raw data.
```
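The `validation` argument accepts two forms: a column name whose rows are 1/True (validation) or 0/False (training), or a float giving the fraction of rows to hold out at random. The sketch below illustrates those two documented behaviors on plain Python rows. `split_rows` is a hypothetical helper, not part of Horovod's API, and the per-row random draw is an assumption loosely modeled on Spark's `DataFrame.randomSplit`; Horovod's own splitting code may differ.

```python
import random
from typing import Any, Dict, List, Tuple, Union

Row = Dict[str, Any]


def split_rows(rows: List[Row], validation: Union[str, float, None],
               seed: int = 0) -> Tuple[List[Row], List[Row]]:
    """Hypothetical helper: split rows into (train, val) following the two
    documented forms of the `validation` argument."""
    if validation is None:
        # No validation requested: every row is training data.
        return list(rows), []
    if isinstance(validation, str):
        # Column form: rows whose column value is 1/True go to validation,
        # rows with 0/False stay in training.
        train = [r for r in rows if not r[validation]]
        val = [r for r in rows if r[validation]]
        return train, val
    if isinstance(validation, float) and 0.0 <= validation < 1.0:
        # Split form (assumed per-row draw): each row lands in validation
        # with probability `validation`, seeded for reproducibility.
        rng = random.Random(seed)
        train, val = [], []
        for r in rows:
            (val if rng.random() < validation else train).append(r)
        return train, val
    raise ValueError(f"unsupported validation value: {validation!r}")
```

With the column form the split is deterministic, while the float form only matches the requested fraction on average, which is why a `random_seed` matters when reproducible validation sets are needed.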
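Several of the documented arguments interact when planning an epoch: `val_batch_size` falls back to `batch_size`, `partitions_per_process` is applied per worker from `num_proc`, and `train_steps_per_epoch` defaults to covering the whole dataset. The following minimal sketch shows that arithmetic under stated assumptions: `plan_training` is hypothetical, the total partition count is assumed to be `num_proc * partitions_per_process`, and the default step count assumes each worker handles an equal share of rows with a final partial batch counted as one step. It is not Horovod's implementation.

```python
import math
from typing import Optional


def plan_training(num_rows: int, num_proc: int, batch_size: int,
                  val_batch_size: Optional[int] = None,
                  partitions_per_process: int = 10,
                  train_steps_per_epoch: Optional[int] = None) -> dict:
    """Hypothetical helper: derive per-epoch settings from estimator arguments."""
    # val_batch_size falls back to batch_size when it is not set.
    effective_val_batch = val_batch_size if val_batch_size is not None else batch_size
    # Assumption: the data is written as num_proc * partitions_per_process partitions.
    num_partitions = num_proc * partitions_per_process
    if train_steps_per_epoch is None:
        # Assumption: "the entire dataset each epoch" means each worker processes
        # its 1/num_proc share of rows, counting a final partial batch as one step.
        train_steps_per_epoch = math.ceil(num_rows / (batch_size * num_proc))
    return {
        "val_batch_size": effective_val_batch,
        "num_partitions": num_partitions,
        "train_steps_per_epoch": train_steps_per_epoch,
    }
```

For example, `plan_training(10_000, num_proc=4, batch_size=32)` returns `{'val_batch_size': 32, 'num_partitions': 40, 'train_steps_per_epoch': 79}`. Passing an explicit `train_steps_per_epoch` bypasses the estimate, which matches the docstring's suggestion to use it for quick checks that the model trains.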