```python
class SparkTrials(Trials):
    """
    Implementation of hyperopt.Trials supporting
    distributed execution using Apache Spark clusters.
    This requires fmin to be run on a Spark cluster.

    Plugging SparkTrials into hyperopt.fmin() allows hyperopt
    to send model training and evaluation tasks to Spark workers,
    parallelizing hyperparameter search.
    Each trial (set of hyperparameter values) is handled within
    a single Spark task; i.e., each model will be fit and evaluated
    on a single worker machine. Trials are run asynchronously.

    See hyperopt.Trials docs for general information about Trials.

    The fields we store in our trial docs match the base Trials class. The fields include:
     - 'tid': trial ID
     - 'state': JOB_STATE_DONE, JOB_STATE_ERROR, etc.
     - 'result': evaluation result for completed trial run
     - 'refresh_time': timestamp for last status update
     - 'misc': includes:
        - 'error': (error type, error message)
     - 'book_time': timestamp for trial run start
    """

    asynchronous = True

    # Hard cap on the number of concurrent hyperopt tasks (Spark jobs) to run. Set at 128.
    MAX_CONCURRENT_JOBS_ALLOWED = 128

    def __init__(
        self, parallelism=None, timeout=None, loss_threshold=None, spark_session=None
    ):
        """
        :param parallelism: Maximum number of parallel trials to run,
                            i.e., maximum number of concurrent Spark tasks.
                            The actual parallelism is subject to available Spark task slots at
                            runtime.
                            If set to None (default) or a non-positive value, this will be set to
                            Spark's default parallelism or `1`.
                            We cap the value at `MAX_CONCURRENT_JOBS_ALLOWED=128`.
        :param timeout: Maximum time (in seconds) which fmin is allowed to take.
                        If this timeout is hit, then fmin will cancel running and proposed trials.
                        It will retain all completed trial runs and return the best result found
                        so far.
        :param spark_session: A SparkSession object. If None is passed, SparkTrials will attempt
                              to use an existing SparkSession or create a new one. SparkSession is
                              the entry point for various facilities provided by Spark. For more
                              information, visit the documentation for PySpark.
        """
        super().__init__(exp_key=None, refresh=False)
        if not _have_spark:
            raise Exception(
                "SparkTrials cannot import pyspark classes. Make sure that PySpark "
                "is available in your environment. E.g., try running 'import pyspark'"
            )
        validate_timeout(timeout)
        validate_loss_threshold(loss_threshold)
```
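The parallelism rule in the `__init__` docstring can be restated as a small pure function. This is a minimal sketch and not hyperopt's internal code. The name `resolve_parallelism` is hypothetical. The argument `spark_default_parallelism` stands in for whatever Spark reports at runtime (for example `SparkContext.defaultParallelism`). Reading "Spark's default parallelism or `1`" as `max(default, 1)` is an assumption. The sketch does not model the further limit imposed by available task slots at runtime.

```python
from typing import Optional

# Mirrors SparkTrials.MAX_CONCURRENT_JOBS_ALLOWED from the listing above.
MAX_CONCURRENT_JOBS_ALLOWED = 128


def resolve_parallelism(requested: Optional[int], spark_default_parallelism: int) -> int:
    """Hypothetical helper restating the documented parallelism rule.

    None or a non-positive request falls back to Spark's default parallelism
    (assumed to mean max(default, 1)); the result is capped at 128.
    """
    if requested is None or requested <= 0:
        parallelism = max(spark_default_parallelism, 1)
    else:
        parallelism = requested
    # Apply the hard cap on concurrent trials.
    return min(parallelism, MAX_CONCURRENT_JOBS_ALLOWED)


print(resolve_parallelism(None, 8))  # → 8   (falls back to Spark's default)
print(resolve_parallelism(0, 0))     # → 1   (default is not positive, so 1)
print(resolve_parallelism(500, 8))   # → 128 (capped)
```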
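The class docstring says trials run asynchronously. The `timeout` docs say that fmin cancels outstanding trials but keeps the completed ones. The toy below imitates that behavior on a single machine, using threads from `concurrent.futures` in place of Spark tasks. It is an illustrative analogue under those assumptions and does not show how SparkTrials actually schedules Spark jobs. The function `run_trials` and its arguments are hypothetical, and a fixed candidate list replaces hyperopt's search algorithm.

```python
import time
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from typing import Callable, Dict, List, Optional, Tuple


def run_trials(
    objective: Callable[[float], float],
    candidates: List[float],
    parallelism: int,
    timeout: Optional[float] = None,
) -> Tuple[Optional[float], List[Tuple[float, float]]]:
    """Toy analogue of asynchronous trials with a global timeout.

    At most `parallelism` evaluations run at once. Once `timeout` seconds have
    elapsed, no new trials start and unfinished ones are abandoned, but
    completed (x, loss) pairs are kept. Returns (best x or None, completed).
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    queue = list(candidates)  # proposed but not yet started
    in_flight: Dict[Future, float] = {}  # running future -> its candidate x
    completed: List[Tuple[float, float]] = []

    pool = ThreadPoolExecutor(max_workers=parallelism)
    try:
        while queue or in_flight:
            remaining = None if deadline is None else deadline - time.monotonic()
            if remaining is not None and remaining <= 0:
                break  # timeout hit: stop launching and stop waiting
            # Top up to `parallelism` concurrent trials.
            while queue and len(in_flight) < parallelism:
                x = queue.pop(0)
                in_flight[pool.submit(objective, x)] = x
            # Block until at least one trial finishes or the deadline passes.
            done, _ = wait(in_flight.keys(), timeout=remaining, return_when=FIRST_COMPLETED)
            for fut in done:
                x = in_flight.pop(fut)
                if fut.exception() is None:  # failed trials never count as "best"
                    completed.append((x, fut.result()))
    finally:
        # Threads cannot be killed: abandoned trials finish in the background and
        # their results are ignored. cancel_futures (Python 3.9+) drops queued work.
        pool.shutdown(wait=False, cancel_futures=True)

    best = min(completed, key=lambda pair: pair[1])[0] if completed else None
    return best, completed


best, results = run_trials(lambda x: (x - 3) ** 2, [0, 1, 2, 3, 4, 5], parallelism=2)
print(best, len(results))  # → 3 6
```

A Python thread cannot be stopped from outside. For that reason, the sketch abandons running trials rather than truly cancelling them. When the timeout hits, the best result so far is computed only from the trials that had already finished.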
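The trial-doc fields listed in the class docstring are enough to pick out the best result found so far. This sketch assumes the usual hyperopt convention that a successful `result` dict carries `'loss'` and `'status'` keys. The constants are local string stand-ins for the ones defined in `hyperopt.base`. The trial docs and timestamps are made up for illustration.

```python
from datetime import datetime
from typing import Any, Dict, List, Optional

# Local stand-ins for constants defined in hyperopt.base; the values here are
# illustrative, so import the real ones from hyperopt in actual code.
JOB_STATE_DONE = "done"
JOB_STATE_ERROR = "error"
STATUS_OK = "ok"


def best_trial(trials: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Return the finished, successful trial doc with the lowest loss, or None."""
    finished = [
        t
        for t in trials
        # Check 'state' first so errored trials with empty results are skipped.
        if t["state"] == JOB_STATE_DONE and t["result"].get("status") == STATUS_OK
    ]
    return min(finished, key=lambda t: t["result"]["loss"], default=None)


t0 = datetime(2024, 1, 1, 12, 0, 0)  # hypothetical timestamp
trials = [
    {"tid": 0, "state": JOB_STATE_DONE, "result": {"loss": 0.42, "status": STATUS_OK},
     "book_time": t0, "refresh_time": t0, "misc": {}},
    {"tid": 1, "state": JOB_STATE_DONE, "result": {"loss": 0.17, "status": STATUS_OK},
     "book_time": t0, "refresh_time": t0, "misc": {}},
    {"tid": 2, "state": JOB_STATE_ERROR, "result": {},
     "book_time": t0, "refresh_time": t0,
     "misc": {"error": ("ValueError", "learning rate must be positive")}},
]

print(best_trial(trials)["tid"])  # → 1
```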