(args)
| 134 | |
| 135 | |
def main(args):
    """Start the local cluster jobs and block until the user asks to stop.

    Validates ``args``, builds a cluster spec with ``args.num_pss`` 'ps'
    addresses and one 'worker' address per entry of ``GPU_IDS`` (or
    ``args.cpu_trainers`` addresses when ``GPU_IDS`` is empty), starts the
    jobs through ``create_tf_jobs`` and then waits for a line on stdin or a
    keyboard interrupt. Every started subprocess is killed and waited for
    before the function returns.

    Args:
        args: Parsed arguments; this function reads ``file``, ``args``,
            ``num_pss`` and ``cpu_trainers``.
    """
    validate_arguments(args)

    # Computed once so an empty (or unset) GPU_IDS is reported as 0 GPUs
    # instead of being passed to len() a second time.
    num_gpus = len(GPU_IDS) if GPU_IDS else 0
    num_workers = num_gpus if num_gpus else args.cpu_trainers
    print('Using program %s with args %s' % (args.file, ' '.join(args.args)))
    print('Using %d workers, %d parameter servers, %d GPUs.' % (num_workers, args.num_pss, num_gpus))

    # Parameter servers take the first ports after PORT_BASE, workers the
    # following ones. Addresses are plain "host:port" with no whitespace;
    # the previous 'localhost: %d' put a space in front of the port.
    worker_port_base = PORT_BASE + args.num_pss
    cluster_spec = {
        'ps': ['localhost:%d' % (PORT_BASE + i) for i in range(args.num_pss)],
        'worker': ['localhost:%d' % (worker_port_base + i) for i in range(num_workers)],
    }

    processes = []
    try:
        # Collect the processes one at a time inside the try block so that
        # the ones already started are still stopped if a later launch fails
        # or the user interrupts while jobs are being created.
        for process in create_tf_jobs(cluster_spec, args.file, args.args):
            processes.append(process)
        print('Press ENTER to exit the training ...')
        sys.stdin.readline()
    except KeyboardInterrupt:  # https://docs.python.org/3/library/exceptions.html#KeyboardInterrupt
        print('Keyboard interrupt received')
    finally:
        print('stopping all subprocesses ...')
        # Signal every process first, then reap them, so shutdown is not
        # serialized on the slowest process.
        for process in processes:
            process.kill()
        for process in processes:
            process.wait()
    print('END')
| 158 | |
| 159 | |
| 160 | def build_arg_parser(parser): |
no test coverage detected
searching dependent graphs…