MCPcopy
hub / github.com/uber/petastorm / run

Function run

examples/spark_dataset_converter/tensorflow_converter_example.py:56–107  ·  view source on GitHub ↗
(data_dir)

Source from the content-addressed store, hash-verified

54
55
def run(data_dir):
    """Train and evaluate a TensorFlow model on LibSVM data via petastorm's Spark converter.

    The dataset at ``data_dir`` is loaded with Spark as ``(features, label)``
    rows, with ``label`` cast to long. It is split 90/10 (seed 12345) into
    train and test DataFrames, and each is wrapped with
    ``make_spark_converter``. A model is then trained and evaluated twice:
    once in the driver process and once inside a single Spark task. Both
    results are logged as accuracy.

    The Spark session and both converters are cleaned up even if any step
    raises; the exception still propagates to the caller.

    Args:
        data_dir: Location passed to ``spark.read.format("libsvm").load``.
            Rows are read with ``numFeatures=784`` and later reshaped to
            28x28.
    """
    from contextlib import ExitStack

    with ExitStack() as cleanup:
        # Get SparkSession
        spark = (SparkSession.builder
                 .master("local[2]")
                 .appName("petastorm.spark tensorflow_example")
                 .getOrCreate())
        # ExitStack runs callbacks in LIFO order, so on exit (normal or via an
        # exception) the converters registered below are deleted first and the
        # session is stopped last -- the same order as the original cleanup.
        cleanup.callback(spark.stop)

        # Load and preprocess data using Spark
        df = (spark.read.format("libsvm")
              .option("numFeatures", "784")
              .load(data_dir)
              .select(col("features"), col("label").cast("long").alias("label")))

        # Randomly split data into train and test dataset
        df_train, df_test = df.randomSplit([0.9, 0.1], seed=12345)

        # Set a cache directory for intermediate data.
        # The path should be accessible by both Spark workers and driver.
        spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF,
                       "file:///tmp/petastorm/cache/tf-example")

        converter_train = make_spark_converter(df_train)
        cleanup.callback(converter_train.delete)
        converter_test = make_spark_converter(df_test)
        cleanup.callback(converter_test.delete)

        def train_and_evaluate(_=None):
            """Train on the train converter, evaluate on the test converter.

            The ignored positional parameter lets this function be passed
            directly to ``RDD.map``, which calls it with one element.

            Returns:
                ``hist[1]`` from ``model.evaluate``, logged by the caller as
                accuracy. NOTE(review): this assumes ``train`` compiles the
                model with at least one metric, the first being accuracy.
                Confirm against ``train``.
            """
            # Imported inside the function, presumably so the import is
            # resolved in whichever process runs it (driver or Spark worker).
            import tensorflow.compat.v1 as tf  # pylint: disable=import-error

            def to_image_and_label(x):
                # Flat 784-value feature vectors -> 28x28 images, paired with labels.
                return tf.reshape(x.features, [-1, 28, 28]), x.label

            # No num_epochs given here, so the converter's default applies;
            # presumably `train` bounds the number of steps -- confirm.
            with converter_train.make_tf_dataset() as dataset:
                model = train(dataset.map(to_image_and_label))

            # num_epochs=1: evaluation makes a single pass over the test split.
            with converter_test.make_tf_dataset(num_epochs=1) as dataset:
                hist = model.evaluate(dataset.map(to_image_and_label))

            return hist[1]

        # Train and evaluate the model on the local machine
        accuracy = train_and_evaluate()
        logging.info("Train and evaluate the model on the local machine.")
        logging.info("Accuracy: %.6f", accuracy)

        # Train and evaluate the model on a spark worker. A one-element RDD
        # makes exactly one task run train_and_evaluate.
        accuracy = spark.sparkContext.parallelize(range(1)).map(train_and_evaluate).collect()[0]
        logging.info("Train and evaluate the model remotely on a spark worker, "
                     "which can be used for distributed hyperparameter tuning.")
        logging.info("Accuracy: %.6f", accuracy)
108
109
110def main():

Callers 3

main — Function · 0.70

Calls 5

make_spark_converter — Function · 0.90
set — Method · 0.80
delete — Method · 0.80
train_and_evaluate — Function · 0.70
stop — Method · 0.45

Tested by 2