Sharded computation followed by concat/mean for TPUStrategy.
(fn, *, args, reduction, strategy)
| 28 | |
| 29 | |
def distributed(fn, *, args, reduction, strategy):
    """
    Sharded computation followed by concat/mean for TPUStrategy.

    Runs ``fn`` on every replica of ``strategy`` and combines each leaf of
    the (possibly nested) per-replica output.

    Args:
      fn: Per-replica function; the strategy invokes it as ``fn(*args)``.
      args: Positional arguments forwarded to ``fn`` by the strategy.
      reduction: How each output leaf is combined across replicas:
        ``'mean'`` -- average the leaf across replicas (elementwise), then
        ``tf.reduce_mean`` over all remaining elements, giving a scalar.
        This equals the global mean only if every replica's tensor has the
        same number of elements.
        ``'concat'`` -- concatenate the replica-local results along axis 0.
      strategy: A ``tf.distribute`` strategy (intended for TPUStrategy).

    Returns:
      A structure matching ``fn``'s output, with every leaf combined
      according to ``reduction``.

    Raises:
      NotImplementedError: if ``reduction`` is not ``'mean'`` or ``'concat'``.
        It is raised before ``fn`` is launched.
    """
    # Pick the combiner first so that an unsupported `reduction` fails fast,
    # before the (potentially expensive) replicated computation is started.
    if reduction == 'mean':
        def combine(x):
            # axis=None reduces only across replicas (elementwise). It is
            # passed explicitly because newer TF releases make `axis` a
            # required argument of Strategy.reduce.
            per_replica_mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, x, axis=None)
            return tf.reduce_mean(per_replica_mean)
    elif reduction == 'concat':
        def combine(x):
            # NOTE(review): experimental_local_results only returns the
            # replicas local to this worker. This assumes a single-host setup;
            # confirm before using it on multi-host TPU pods.
            return tf.concat(strategy.experimental_local_results(x), axis=0)
    else:
        raise NotImplementedError(reduction)

    # `Strategy.run` replaced `experimental_run_v2` in TF 2.2. Fall back to the
    # old name on releases that predate the rename.
    run = getattr(strategy, 'run', None) or strategy.experimental_run_v2
    out = run(fn, args=args)
    return tf.nest.map_structure(combine, out)
| 41 | |
| 42 | |
| 43 | # ========== Inception utilities ========== |
no outgoing calls
no test coverage detected