
Class KerasEstimator

horovod/spark/keras/estimator.py:98–385 (github.com/horovod/horovod)

Spark Estimator for fitting Keras models to a DataFrame. Supports standalone `keras` and `tf.keras`, and TensorFlow 1.X and 2.X.


```python
class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable,
                     KerasEstimatorParamsWritable):
    """Spark Estimator for fitting Keras models to a DataFrame.

    Supports standalone `keras` and `tf.keras`, and TensorFlow 1.X and 2.X.

    Args:
        num_proc: Number of Horovod processes. Defaults to `spark.default.parallelism`.
        data_module: (Optional) DataModule class used for training and validation. If not set, defaults to
            PetastormDataModule.
        model: Keras model to train.
        backend: Optional Backend object for running the distributed training function. Defaults to SparkBackend
            with `num_proc` worker processes. Cannot be specified if `num_proc` is also provided.
        store: Store object that abstracts reading and writing of intermediate data and run results.
        custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be considered
            during serialization/deserialization.
        optimizer: Keras optimizer to be converted into a `hvd.DistributedOptimizer` for training.
        loss: Keras loss or list of losses.
        loss_weights: Optional list of float weight values to assign each loss.
        sample_weight_col: Optional column indicating the weight of each sample.
        gradient_compression: Gradient compression used by `hvd.DistributedOptimizer`.
        metrics: Optional metrics to record.
        feature_cols: Column names used as feature inputs to the model. Must be a list with each feature
            mapping, in order, to one of the model's inputs.
        label_cols: Column names used as labels. Must be a list with one label for each output of the model.
        validation: Optional validation column name (string) where every row in the column is either 1/True or
            0/False, or validation split (float) giving the fraction of data to be randomly selected for validation.
        callbacks: Keras callbacks.
        batch_size: Number of rows from the DataFrame per batch.
        val_batch_size: Number of rows from the DataFrame per batch for validation. If not set, batch_size is used.
        epochs: Number of epochs to train.
        verbose: Verbosity level [0, 2] (default: 1).
        random_seed: Optional random seed to use for TensorFlow. Default: None.
        shuffle_buffer_size: (Deprecated) Optional size of in-memory shuffle buffer in rows (on training data).
            Allocating a larger buffer size increases randomness of shuffling at
            the cost of more host memory. Defaults to estimating with an assumption
            of 4GB of memory per host. Setting shuffle_buffer_size=0 turns off shuffling.
        shuffle: (Optional) Whether to shuffle training samples or not. Defaults to True.
        partitions_per_process: Number of Parquet partitions to assign per worker process from `num_proc` (default: 10).
        run_id: Optional unique ID for this run for organization in the Store. Will be automatically assigned if not
            provided.
        train_steps_per_epoch: Number of steps to train each epoch. Useful for testing that the model trains
            successfully. Defaults to training the entire dataset each epoch.
        validation_steps_per_epoch: Number of validation steps to perform each epoch.
        transformation_fn: Optional function that takes a row as its parameter
            and returns a modified row that is then fed into the
            train or validation step. This transformation is
            applied after batching. See Petastorm [TransformSpec](https://github.com/uber/petastorm/blob/master/petastorm/transform.py)
            for more details. Note that this function constructs
            another function which should perform the
            transformation.
        train_reader_num_workers: This parameter specifies the number of parallel processes that
            read the training data from the data store and apply data
            transformations to it. Increasing this number
            will generally increase the reading rate but will also
            increase the memory footprint. More processes are
            particularly useful if the bandwidth to the data store is not
            high enough, or users need to apply transformations such as
            decompression or data augmentation on raw data.
```
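The `num_proc` and `backend` arguments interact: a backend cannot be combined with `num_proc`, and without a backend the default SparkBackend uses `num_proc` workers, falling back to `spark.default.parallelism`. The sketch below encodes only those documented rules. `StubBackend` and `resolve_num_proc` are hypothetical names for illustration; this is not Horovod's code and does not reproduce the real `Backend` API.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubBackend:
    """Hypothetical stand-in for a Horovod Backend; only tracks its worker count."""
    num_workers: int


def resolve_num_proc(num_proc: Optional[int],
                     backend: Optional[StubBackend],
                     default_parallelism: int) -> int:
    """Hypothetical helper applying the documented num_proc/backend rules."""
    # The docstring says backend cannot be specified together with num_proc.
    if num_proc is not None and backend is not None:
        raise ValueError("specify either num_proc or backend, not both")
    # An explicit backend decides how many worker processes run.
    if backend is not None:
        return backend.num_workers
    # Otherwise the default SparkBackend gets num_proc workers, falling back
    # to the value of spark.default.parallelism.
    return num_proc if num_proc is not None else default_parallelism


print(resolve_num_proc(None, None, default_parallelism=16))            # 16
print(resolve_num_proc(4, None, default_parallelism=16))               # 4
print(resolve_num_proc(None, StubBackend(8), default_parallelism=16))  # 8
```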
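For a model with several outputs, `loss` can be a list of losses and `loss_weights` a matching list of floats. Assuming these are passed through to Keras `Model.compile`, whose documentation describes the minimized loss as the weighted sum of the per-output losses, the objective for $K$ outputs is the one below (any regularization losses Keras adds are omitted). The worked numbers are illustrative only: two outputs with losses 0.8 and 0.2 and weights 1.0 and 0.5.

```latex
\mathcal{L}_{\text{total}} = \sum_{k=1}^{K} w_k \, \mathcal{L}_k\left(y_k, \hat{y}_k\right)
\qquad\text{e.g.}\qquad
1.0 \times 0.8 + 0.5 \times 0.2 = 0.9
```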
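The `gradient_compression` argument is handed to `hvd.DistributedOptimizer`; Horovod provides `hvd.Compression.none` and `hvd.Compression.fp16`. The idea behind fp16 compression is to send each gradient as a 2-byte half-precision float instead of a 4- or 8-byte float, trading precision for bandwidth. The following is a standard-library simulation of that idea, not Horovod's implementation: it encodes with `struct`'s `'e'` (half-precision) format and averages after decoding, whereas a real fp16 allreduce also performs the reduction in fp16.

```python
import struct
from typing import List


def compress_fp16(values: List[float]) -> bytes:
    # Encode each value as IEEE 754 half precision ('e'), 2 bytes per value.
    # struct raises OverflowError for magnitudes above the fp16 max (65504).
    return struct.pack(f"<{len(values)}e", *values)


def decompress_fp16(payload: bytes) -> List[float]:
    # Decode back to Python floats; precision lost in encoding is not recovered.
    return list(struct.unpack(f"<{len(payload) // 2}e", payload))


def simulated_allreduce_mean(per_worker_grads: List[List[float]]) -> List[float]:
    # Each worker "sends" fp16-compressed gradients; the receiver decodes them
    # and averages element-wise across workers.
    decoded = [decompress_fp16(compress_fp16(g)) for g in per_worker_grads]
    return [sum(column) / len(decoded) for column in zip(*decoded)]


grads = [[0.1, -2.0, 3.5], [0.3, 2.0, 1.5]]
print(len(compress_fp16(grads[0])))  # 6 (bytes), versus 24 as float64
print(simulated_allreduce_mean(grads))
```

Values such as 0.1 are not exactly representable in fp16, so the averaged result differs slightly from the full-precision mean.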
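The `feature_cols` and `label_cols` rules are positional: the i-th feature column feeds the model's i-th input, and there must be exactly one label column per model output. The helper below is hypothetical and operates on a plain dict standing in for a DataFrame row. It only illustrates the ordering contract, not how Horovod actually builds its input pipeline.

```python
from typing import Any, Dict, List, Sequence, Tuple

Row = Dict[str, Any]


def row_to_model_io(row: Row, feature_cols: Sequence[str], label_cols: Sequence[str],
                    num_outputs: int) -> Tuple[List[Any], List[Any]]:
    """Hypothetical helper: order a row's values the way the docstring describes."""
    # One label column per model output is required.
    if len(label_cols) != num_outputs:
        raise ValueError(f"expected {num_outputs} label columns, got {len(label_cols)}")
    # Position matters: feature_cols[i] feeds the model's i-th input.
    inputs = [row[col] for col in feature_cols]
    # labels[j] is the target for the model's j-th output.
    labels = [row[col] for col in label_cols]
    return inputs, labels


row = {"age": 42, "income": 1.5, "churned": 0, "lifetime_value": 3.2}
print(row_to_model_io(row, ["age", "income"], ["churned", "lifetime_value"], num_outputs=2))
# ([42, 1.5], [0, 3.2])
```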
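The `validation` argument accepts two forms: a column name whose values flag validation rows, or a float split. The sketch below illustrates both on a list of dicts, assuming the float is a fraction in [0, 1) and that each row is selected independently at random. `split_train_val` is a hypothetical name; Horovod performs this split on the Spark DataFrame, not in Python lists.

```python
import random
from typing import Any, Dict, List, Optional, Tuple, Union

Row = Dict[str, Any]


def split_train_val(rows: List[Row], validation: Optional[Union[str, float]],
                    seed: int = 0) -> Tuple[List[Row], List[Row]]:
    """Hypothetical helper illustrating the two documented forms of `validation`."""
    if validation is None:
        return list(rows), []
    if isinstance(validation, str):
        # Column form: a truthy flag (1/True) marks a validation row.
        train = [r for r in rows if not r[validation]]
        val = [r for r in rows if r[validation]]
        return train, val
    if isinstance(validation, float):
        # Split form (assumed to be a fraction in [0, 1)): each row lands in
        # the validation set independently with that probability.
        if not 0.0 <= validation < 1.0:
            raise ValueError("validation split must be in [0, 1)")
        rng = random.Random(seed)
        train, val = [], []
        for r in rows:
            (val if rng.random() < validation else train).append(r)
        return train, val
    raise TypeError("validation must be a column name (str) or a float split")


rows = [{"x": i, "is_val": i % 5 == 0} for i in range(10)]
train, val = split_train_val(rows, "is_val")
print([r["x"] for r in val])  # [0, 5]
```

Because rows are drawn independently in the float form, the validation set is only approximately the requested fraction of the data.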
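The deprecated `shuffle_buffer_size` describes a trade-off: a bigger in-memory buffer gives more random ordering but costs more host memory, and a size of 0 disables shuffling. A generic streaming shuffle buffer, sketched below, shows why. This is the common reservoir-style technique, not Horovod's or Petastorm's exact code. With a buffer of `k` items, an element can be emitted at most `k - 1` positions after its original index, so `k = 1` preserves the input order.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Generic streaming shuffle buffer (illustrative sketch)."""
    if buffer_size <= 0:
        # A zero-sized buffer disables shuffling: pass items through in order.
        yield from items
        return
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            # Fill the buffer before emitting anything.
            buffer.append(item)
            continue
        # Emit a random buffered item and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


print(list(buffered_shuffle(range(10), buffer_size=1)))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(list(buffered_shuffle(range(10), buffer_size=5)))
```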
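For `partitions_per_process`, the arithmetic is worth spelling out. Assuming the estimator writes the training data as roughly `num_proc × partitions_per_process` Parquet partitions, 4 processes with the default of 10 yield 40 partitions. The round-robin assignment below is a hypothetical illustration of how those partitions could be split evenly across workers; it is not taken from Horovod or Petastorm.

```python
from typing import Dict, List


def total_partitions(num_proc: int, partitions_per_process: int = 10) -> int:
    # Assumed relationship: one batch of partitions_per_process partitions per worker.
    return num_proc * partitions_per_process


def assign_round_robin(num_partitions: int, num_proc: int) -> Dict[int, List[int]]:
    # Hypothetical assignment: worker `rank` reads partitions rank, rank + num_proc, ...
    return {rank: list(range(rank, num_partitions, num_proc)) for rank in range(num_proc)}


n = total_partitions(num_proc=4)
shards = assign_round_robin(n, num_proc=4)
print(n, shards[1][:3], len(shards[1]))  # 40 [1, 5, 9] 10
```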
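When `train_steps_per_epoch` is not set, the docstring says each epoch covers the entire dataset. Under the usual data-parallel convention (assumed here), every one of the `num_proc` workers consumes one batch per step, so a full pass takes `ceil(rows / (batch_size × num_proc))` steps. The sketch also applies the documented fallback of `val_batch_size` to `batch_size`. Function names are hypothetical and the row counts are made-up inputs.

```python
from typing import Optional


def ceil_div(a: int, b: int) -> int:
    # Integer ceiling division, avoiding float rounding on large row counts.
    return -(-a // b)


def default_steps_per_epoch(num_rows: int, batch_size: int, num_proc: int) -> int:
    # Every step, each of the num_proc workers consumes one batch, so a full
    # pass over the data takes ceil(rows / (batch_size * num_proc)) steps.
    return ceil_div(num_rows, batch_size * num_proc)


def effective_val_batch_size(batch_size: int, val_batch_size: Optional[int] = None) -> int:
    # Per the docstring, val_batch_size falls back to batch_size when unset.
    return val_batch_size if val_batch_size is not None else batch_size


train_steps = default_steps_per_epoch(1_000_000, batch_size=128, num_proc=8)
val_steps = default_steps_per_epoch(100_000, effective_val_batch_size(128), num_proc=8)
print(train_steps, val_steps)  # 977 98
```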
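The `transformation_fn` note says the supplied function constructs another function that performs the transformation. A factory closure is one way to read that, sketched below with a hypothetical `make_standardize_transform` that standardizes one column. It treats a row as a plain dict; in Petastorm's batched readers the transform may instead receive a pandas DataFrame, so check the TransformSpec documentation before relying on the row shape assumed here.

```python
from typing import Any, Callable, Dict

Row = Dict[str, Any]


def make_standardize_transform(col: str, mean: float, std: float) -> Callable[[Row], Row]:
    """Hypothetical factory: returns the function that performs the transformation."""
    def transform(row: Row) -> Row:
        # Copy so the caller's row is not mutated in place.
        out = dict(row)
        out[col] = (row[col] - mean) / std
        return out
    return transform


scale_x = make_standardize_transform("x", mean=10.0, std=2.0)
original = {"x": 14.0, "label": 1}
print(scale_x(original))  # {'x': 2.0, 'label': 1}
print(original)           # {'x': 14.0, 'label': 1}
```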
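`train_reader_num_workers` controls how many workers read and transform data in parallel. The sketch below models that with a thread pool decoding zlib-compressed chunks, a stand-in for the "decompression" case the docstring mentions. It uses threads for simplicity while the docstring describes processes, and it does not model the larger memory footprint that more workers bring. `decode_chunk` and `read_all` are hypothetical names.

```python
import zlib
from concurrent.futures import ThreadPoolExecutor
from typing import List, Sequence


def decode_chunk(blob: bytes) -> List[int]:
    # Stand-in for per-chunk reader work, here decompression of raw bytes.
    return list(zlib.decompress(blob))


def read_all(blobs: Sequence[bytes], num_workers: int) -> List[int]:
    # num_workers plays the role of train_reader_num_workers: more workers
    # means more chunks decoded concurrently.
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        decoded = pool.map(decode_chunk, blobs)  # results come back in input order
        return [value for chunk in decoded for value in chunk]


chunks = [zlib.compress(bytes([i] * 3)) for i in range(4)]
print(read_all(chunks, num_workers=2))  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
```

Because `map` yields results in submission order, the output is identical for any worker count; only throughput changes.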

