| 54 | |
| 55 | |
def run(data_dir):
    """Train and evaluate a TensorFlow model on LIBSVM data via petastorm Spark converters.

    Reads ``data_dir`` with Spark, splits it 90/10 into train/test DataFrames,
    wraps each split in a petastorm converter, then trains and evaluates the
    model twice: once in the driver process and once inside a single Spark task.
    The accuracy of each run is logged.

    The Spark session and both converters are released even if any step fails.

    Args:
        data_dir: Location of the LIBSVM-formatted data, passed to Spark's ``load``.
    """
    from contextlib import ExitStack

    # Get SparkSession
    spark = (SparkSession.builder
             .master("local[2]")
             .appName("petastorm.spark tensorflow_example")
             .getOrCreate())

    # Every resource is registered for cleanup as soon as it is acquired.
    # ExitStack runs callbacks in reverse registration order (test converter,
    # train converter, then Spark session, matching the original cleanup
    # order). It keeps running the remaining callbacks if one of them raises.
    with ExitStack() as cleanup:
        cleanup.callback(spark.stop)

        # Load and preprocess data using Spark
        df = (spark.read.format("libsvm")
              .option("numFeatures", "784")
              .load(data_dir)
              .select(col("features"), col("label").cast("long").alias("label")))

        # Randomly split data into train and test dataset
        df_train, df_test = df.randomSplit([0.9, 0.1], seed=12345)

        # Set a cache directory for intermediate data.
        # The path should be accessible by both Spark workers and driver.
        spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF,
                       "file:///tmp/petastorm/cache/tf-example")

        converter_train = make_spark_converter(df_train)
        cleanup.callback(converter_train.delete)
        converter_test = make_spark_converter(df_test)
        cleanup.callback(converter_test.delete)

        def train_and_evaluate(_=None):
            """Train on the train converter and return the metric from evaluating on the test converter.

            The ignored positional argument lets this function be passed
            directly as the mapper to ``rdd.map`` below. Only the converters
            are captured by this closure, so it can be shipped to a Spark worker.
            """
            import tensorflow.compat.v1 as tf  # pylint: disable=import-error

            def to_image_and_label(x):
                # 784 features (see numFeatures above) -> 28x28. The leading -1
                # keeps whatever leading dimension the dataset yields
                # (presumably the batch dimension -- confirm).
                return tf.reshape(x.features, [-1, 28, 28]), x.label

            with converter_train.make_tf_dataset() as dataset:
                model = train(dataset.map(to_image_and_label))

            # num_epochs=1 so that evaluation makes a single pass over the test set.
            with converter_test.make_tf_dataset(num_epochs=1) as dataset:
                hist = model.evaluate(dataset.map(to_image_and_label))

            # NOTE(review): assumes `train` compiles the model with exactly one
            # metric (accuracy), so evaluate() returns [loss, accuracy] -- confirm.
            return hist[1]

        # Train and evaluate the model on the local machine
        accuracy = train_and_evaluate()
        logging.info("Train and evaluate the model on the local machine.")
        logging.info("Accuracy: %.6f", accuracy)

        # Train and evaluate the model on a spark worker: a one-element RDD
        # runs train_and_evaluate exactly once inside a Spark task.
        accuracy = spark.sparkContext.parallelize(range(1)).map(train_and_evaluate).collect()[0]
        logging.info("Train and evaluate the model remotely on a spark worker, "
                     "which can be used for distributed hyperparameter tuning.")
        logging.info("Accuracy: %.6f", accuracy)
| 108 | |
| 109 | |
| 110 | def main(): |