:param parallelism: Maximum number of parallel trials to run, i.e., maximum number of concurrent Spark tasks. The actual parallelism is subject to available Spark task slots at runtime.
(
self, parallelism=None, timeout=None, loss_threshold=None, spark_session=None
)
| 55 | MAX_CONCURRENT_JOBS_ALLOWED = 128 |
| 56 | |
| 57 | def __init__( |
| 58 | self, parallelism=None, timeout=None, loss_threshold=None, spark_session=None |
| 59 | ): |
| 60 | """ |
| 61 | :param parallelism: Maximum number of parallel trials to run, |
| 62 | i.e., maximum number of concurrent Spark tasks. |
| 63 | The actual parallelism is subject to available Spark task slots at |
| 64 | runtime. |
| 65 | If set to None (default) or a non-positive value, this will be set to |
| 66 | Spark's default parallelism or `1`. |
| 67 | We cap the value at `MAX_CONCURRENT_JOBS_ALLOWED=128`. |
| 68 | :param timeout: Maximum time (in seconds) which fmin is allowed to take. |
| 69 | If this timeout is hit, then fmin will cancel running and proposed trials. |
| 70 | It will retain all completed trial runs and return the best result found |
| 71 | so far. |
| 72 | :param spark_session: A SparkSession object. If None is passed, SparkTrials will attempt |
| 73 | to use an existing SparkSession or create a new one. SparkSession is |
| 74 | the entry point for various facilities provided by Spark. For more |
| 75 | information, visit the documentation for PySpark. |
| 76 | """ |
| 77 | super().__init__(exp_key=None, refresh=False) |
| 78 | if not _have_spark: |
| 79 | raise Exception( |
| 80 | "SparkTrials cannot import pyspark classes. Make sure that PySpark " |
| 81 | "is available in your environment. E.g., try running 'import pyspark'" |
| 82 | ) |
| 83 | validate_timeout(timeout) |
| 84 | validate_loss_threshold(loss_threshold) |
| 85 | self._spark = ( |
| 86 | SparkSession.builder.getOrCreate() |
| 87 | if spark_session is None |
| 88 | else spark_session |
| 89 | ) |
| 90 | self._spark_context = self._spark.sparkContext |
| 91 | self._spark_pinned_threads_enabled = isinstance( |
| 92 | self._spark_context._gateway, ClientServer |
| 93 | ) |
| 94 | # The feature to support controlling jobGroupIds is in SPARK-22340 |
| 95 | self._spark_supports_job_cancelling = ( |
| 96 | self._spark_pinned_threads_enabled |
| 97 | or hasattr(self._spark_context.parallelize([1]), "collectWithJobGroup") |
| 98 | ) |
| 99 | spark_default_parallelism = self._spark_context.defaultParallelism |
| 100 | self.parallelism = self._decide_parallelism( |
| 101 | requested_parallelism=parallelism, |
| 102 | spark_default_parallelism=spark_default_parallelism, |
| 103 | ) |
| 104 | |
| 105 | if not self._spark_supports_job_cancelling and timeout is not None: |
| 106 | logger.warning( |
| 107 | "SparkTrials was constructed with a timeout specified, but this Apache " |
| 108 | "Spark version does not support job group-based cancellation. The " |
| 109 | "timeout will be respected when starting new Spark jobs, but " |
| 110 | "SparkTrials will not be able to cancel running Spark jobs which exceed" |
| 111 | " the timeout." |
| 112 | ) |
| 113 | |
| 114 | self.timeout = timeout |
Nothing calls `__init__` directly; it runs only when a `SparkTrials` object is constructed.
No test coverage was detected for this `__init__` method.