Register a validator to check directory flags. Args: flag_names: An iterable of strings containing the names of flags to be checked.
(flag_names)
| 21 | |
| 22 | |
def require_cloud_storage(flag_names):
    """Register a validator to check directory flags.

    Registers an absl multi-flag validator over ``tpu`` plus every flag in
    ``flag_names``. The check passes when the ``tpu`` flag is ``None``;
    otherwise every named flag must hold a string that starts with ``gs://``.
    Each offending flag is logged individually before the check fails, and the
    validator's failure message names all checked flags.

    Args:
      flag_names: An iterable of strings containing the names of flags to be
        checked. A single flag name passed as a plain string is also accepted.

    Returns:
      None. The validator is registered as a side effect of applying the
      ``flags.multi_flags_validator`` decorator.
    """
    # Materialize the names exactly once: they are used three times below
    # (message, validator flag list, check loop). A generator would be exhausted
    # by the first use, and a tuple/set cannot be concatenated onto a list.
    if isinstance(flag_names, str):
        flag_names = [flag_names]
    else:
        flag_names = list(flag_names)

    msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

    @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
    def _path_check(flag_values):  # pylint: disable=missing-docstring
        # No TPU requested: local paths are fine.
        if flag_values["tpu"] is None:
            return True

        # Check every flag (no early exit) so each bad one gets its own log line.
        valid_flags = True
        for key in flag_names:
            value = flag_values[key]
            # An unset flag (None) is not a GCS path either; report it as
            # invalid instead of crashing with AttributeError on .startswith.
            if not isinstance(value, str) or not value.startswith("gs://"):
                logging.error("%s must be a GCS path.", key)
                valid_flags = False

        return valid_flags
| 44 | |
| 45 | |
| 46 | def define_device(tpu=True): |