Repository: github.com/hojonathanho/diffusion

Function distributed

diffusion_tf/tpu_utils/tpu_utils.py, lines 30–40

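Sharded computation followed by concat/mean for TPUStrategy. The function runs fn on every replica of the strategy and then combines the per-replica outputs. With reduction='mean', each output is averaged across replicas and then over all of its remaining elements, giving one scalar per output. With reduction='concat', the per-replica results are concatenated along axis 0. Any other value raises NotImplementedError.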

Signature: distributed(fn, *, args, reduction, strategy). Every parameter after fn is keyword-only.


```python
import tensorflow as tf


def distributed(fn, *, args, reduction, strategy):
    """
    Sharded computation followed by concat/mean for TPUStrategy.
    """
    # Run fn once per replica; out holds per-replica values.
    out = strategy.experimental_run_v2(fn, args=args)
    if reduction == 'mean':
        # Average across replicas, then reduce each tensor to a scalar mean.
        return tf.nest.map_structure(lambda x: tf.reduce_mean(strategy.reduce('mean', x)), out)
    elif reduction == 'concat':
        # Concatenate the per-replica results along the batch axis.
        return tf.nest.map_structure(lambda x: tf.concat(strategy.experimental_local_results(x), axis=0), out)
    else:
        raise NotImplementedError(reduction)
```
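The arithmetic of the two reduction branches can be checked without a TPU. The sketch below is a hypothetical pure-Python simulation, assuming two replicas that each return a [2, 2] result as nested lists; mean_reduction and concat_reduction are illustrative names, not functions from the repository. It also assumes every replica returns the same shape, which is the case where the 'mean' branch equals the plain mean over all elements.

```python
# Hypothetical simulation: nested lists stand in for per-replica tensors.
# It mirrors the arithmetic of the 'mean' and 'concat' branches without TensorFlow.


def mean_reduction(per_replica):
    """'mean': elementwise mean across replicas, then mean over all elements."""
    num_replicas = len(per_replica)
    rows, cols = len(per_replica[0]), len(per_replica[0][0])
    # Step 1 (like strategy.reduce('mean', x)): average matching elements across replicas.
    averaged = [[sum(r[i][j] for r in per_replica) / num_replicas
                 for j in range(cols)] for i in range(rows)]
    # Step 2 (like tf.reduce_mean): average every remaining element into one scalar.
    flat = [v for row in averaged for v in row]
    return sum(flat) / len(flat)


def concat_reduction(per_replica):
    """'concat': stack per-replica results along axis 0 (the batch axis)."""
    out = []
    for replica_rows in per_replica:
        out.extend(replica_rows)
    return out


# Two replicas, each producing a [2, 2] result.
replica_outputs = [
    [[1.0, 2.0], [3.0, 4.0]],  # replica 0
    [[5.0, 6.0], [7.0, 8.0]],  # replica 1
]
print(mean_reduction(replica_outputs))    # 4.5
print(concat_reduction(replica_outputs))  # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
```

Because the 'mean' branch applies tf.reduce_mean after the cross-replica average, each output collapses to a scalar, whereas 'concat' preserves every per-replica row.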

Callers 4

_make_bpd_graph (method) · 0.85
__init__ (method) · 0.85
_make_sampling_graph (method) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected
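With no test coverage detected, the dispatch logic can still be exercised with a test double. The sketch below assumes a hypothetical FakeStrategy class (not a TensorFlow API) that implements only experimental_run_v2, reduce and experimental_local_results over Python lists, plus map_dict, a dict-only stand-in for tf.nest.map_structure. distributed_sketch mirrors the three branches of distributed but is an illustration, not the repository's code.

```python
# Hypothetical test double: FakeStrategy is NOT a TensorFlow class. It mimics only
# the three strategy methods distributed() calls, using lists as per-replica values.


class FakeStrategy:
    def __init__(self, num_replicas):
        self.num_replicas = num_replicas
        self.replica_id = None  # set while fn runs, loosely like a replica context

    def experimental_run_v2(self, fn, args=()):
        per_replica = []
        for r in range(self.num_replicas):
            self.replica_id = r
            per_replica.append(fn(*args))
        self.replica_id = None
        # Regroup [{key: value}, ...] into {key: [value_replica0, value_replica1, ...]}.
        return {k: [out[k] for out in per_replica] for k in per_replica[0]}

    def reduce(self, op, per_replica_values):
        # Elementwise mean across replicas; only 'mean' is supported in this sketch.
        if op != 'mean':
            raise ValueError(op)
        n = len(per_replica_values)
        return [sum(col) / n for col in zip(*per_replica_values)]

    def experimental_local_results(self, per_replica_values):
        return tuple(per_replica_values)


def map_dict(func, structure):
    # Dict-only stand-in for tf.nest.map_structure.
    return {k: func(v) for k, v in structure.items()}


def distributed_sketch(fn, *, args, reduction, strategy):
    out = strategy.experimental_run_v2(fn, args=args)

    def mean_leaf(x):
        # Cross-replica mean, then mean over the remaining elements (like tf.reduce_mean).
        reduced = strategy.reduce('mean', x)
        return sum(reduced) / len(reduced)

    def concat_leaf(x):
        # Join the per-replica batches along axis 0.
        return [v for part in strategy.experimental_local_results(x) for v in part]

    if reduction == 'mean':
        return map_dict(mean_leaf, out)
    elif reduction == 'concat':
        return map_dict(concat_leaf, out)
    else:
        raise NotImplementedError(reduction)


strategy = FakeStrategy(num_replicas=2)


def step(scale):
    r = strategy.replica_id
    # Each replica returns a 2-element "batch" that depends on its replica id.
    return {'loss': [scale * (2 * r + 1), scale * (2 * r + 2)]}


print(distributed_sketch(step, args=(1.0,), reduction='mean', strategy=strategy))
# {'loss': 2.5}
print(distributed_sketch(step, args=(1.0,), reduction='concat', strategy=strategy))
# {'loss': [1.0, 2.0, 3.0, 4.0]}
```

When testing against a real tf.distribute strategy instead, note that TensorFlow 2.2 renamed experimental_run_v2 to run; the function as listed calls the older name.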