| 16 | |
| 17 | |
def _import_external_ops(message):
    """
    Import the Python module that registers the ops mentioned in ``message``.

    ``message`` is matched by substring against a fixed set of known op
    providers: horovod, memory_stats (``MaxBytesInUse``), NCCL, and
    ``zmq_ops`` (``ZMQConnection``). The first match wins. Its module is
    imported only for the side effect of importing it, and the function
    then returns. If nothing matches, the message is logged as an error and
    nothing is imported.

    NOTE(review): ``message`` is presumably the text of an error raised while
    loading a graph that references not-yet-registered ops. Confirm this
    against the callers.

    Args:
        message (str): text to scan for known op names.
    """
    # Only the horovod check is case-insensitive; the others match exact case.
    if "horovod" in message.lower():
        logger.info("Importing horovod ...")
        import horovod.tensorflow  # noqa
        return
    if "MaxBytesInUse" in message:
        logger.info("Importing memory_stats ...")
        from tensorflow.contrib.memory_stats import MaxBytesInUse  # noqa
        return
    if 'Nccl' in message:
        logger.info("Importing nccl ...")
        # The NCCL op wrappers live under tensorflow.contrib up to 1.12 and
        # under tensorflow.python.ops afterwards.
        if TF_version <= (1, 12):
            try:
                from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so
            except Exception as e:
                # Best-effort: this private helper may not be importable in
                # every build. Continue without it, but record why it was
                # skipped instead of discarding the failure silently.
                logger.info("Skipping explicit NCCL library load: " + str(e))
            else:
                # Judging by its name, this validates and loads the NCCL
                # shared object before the generated ops are imported.
                _validate_and_load_nccl_so()
            from tensorflow.contrib.nccl.ops import gen_nccl_ops  # noqa
        else:
            from tensorflow.python.ops import gen_nccl_ops  # noqa
        return
    if 'ZMQConnection' in message:
        logger.info("Importing zmq_ops ...")
        import zmq_ops  # noqa
        return
    logger.error("Unhandled error: " + message)
| 44 | |
| 45 | |
| 46 | def guess_inputs(input_dir): |