Repository: github.com/openai/improved-diffusion

Function mean_flat

improved_diffusion/nn.py, lines 86–90

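Take the mean over all non-batch dimensions. Every dimension except dim 0 (the batch) is averaged away, so an input of shape (N, C, H, W) produces an output of shape (N,), one value per example.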

Signature: mean_flat(tensor)

Source

```python
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    # Reduce over dims 1 .. ndim-1; dim 0 (the batch) is kept.
    return tensor.mean(dim=list(range(1, len(tensor.shape))))
```
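To make the reduction concrete without requiring PyTorch, the sketch below applies the same rule to nested Python lists standing in for a tensor: the outer list is the batch, and everything inside each batch element is averaged. The names mean_flat_ref and _flatten are hypothetical and not part of improved-diffusion; the real function operates on torch tensors and returns a tensor of shape (N,).

```python
from typing import Any, List


def _flatten(x: Any) -> List[float]:
    """Recursively flatten nested lists/tuples into a flat list of floats."""
    if isinstance(x, (list, tuple)):
        out: List[float] = []
        for item in x:
            out.extend(_flatten(item))
        return out
    return [float(x)]


def mean_flat_ref(batch: List[Any]) -> List[float]:
    """Hypothetical pure-Python analogue of mean_flat.

    The outer list is the batch dimension (dim 0); all values nested inside
    each batch element are averaged, giving one mean per element.
    """
    means = []
    for example in batch:
        values = _flatten(example)
        means.append(sum(values) / len(values))
    return means


# A batch of two 2x2 "images": shape (2, 2, 2) -> result shape (2,)
batch = [
    [[1, 2], [3, 4]],
    [[0, 0], [0, 8]],
]
print(mean_flat_ref(batch))  # [2.5, 2.0]
```

Keeping dim 0 intact is what lets callers compute a loss per example and decide separately how to combine examples (for instance, by weighting or averaging over the batch).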

Callers (4)

_vb_terms_bpd (method) · 0.85
training_losses (method) · 0.85
_prior_bpd (method) · 0.85
calc_bpd_loop (method) · 0.85
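The callers are loss and likelihood computations in the diffusion code, and they appear to share one pattern: compute an elementwise quantity, then call mean_flat to get one value per example. The sketch below mimics two such uses with plain Python lists: a per-example squared error, in the spirit of mean_flat((target - model_output) ** 2), and a KL term averaged per example and converted from nats to bits per dimension by dividing by log 2. All helper names are hypothetical; this illustrates the pattern and is not the repository's code.

```python
import math
from typing import Any, List


def _flatten(x: Any) -> List[float]:
    # Recursively flatten nested lists into a flat list of floats.
    if isinstance(x, (list, tuple)):
        return [v for item in x for v in _flatten(item)]
    return [float(x)]


def mean_flat_ref(batch: List[Any]) -> List[float]:
    # One mean per batch element, averaged over all non-batch entries.
    return [sum(v) / len(v) for v in (_flatten(ex) for ex in batch)]


def per_example_mse(target: List[Any], pred: List[Any]) -> List[float]:
    # Analogue of mean_flat((target - model_output) ** 2): squared error
    # averaged within each example (assumes matching shapes).
    sq = [
        [(t - p) ** 2 for t, p in zip(_flatten(te), _flatten(pe))]
        for te, pe in zip(target, pred)
    ]
    return mean_flat_ref(sq)


def kl_bits_per_dim(kl_nats: List[Any]) -> List[float]:
    # Analogue of mean_flat(kl) / log(2): average the elementwise KL (nats)
    # over non-batch entries, then convert nats to bits.
    return [m / math.log(2.0) for m in mean_flat_ref(kl_nats)]


target = [[0.0, 0.0], [1.0, 1.0]]
pred = [[1.0, 1.0], [1.0, 3.0]]
print(per_example_mse(target, pred))  # [1.0, 2.0]
print(kl_bits_per_dim([[math.log(2.0), math.log(2.0)], [0.0, 0.0]]))  # [1.0, 0.0]
```

Averaging (rather than summing) over non-batch dimensions is what makes the result a "per dimension" quantity, independent of image size.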

Calls

No calls to other functions in this repository; it uses only Tensor.mean and Python built-ins.

Tested by

no test coverage detected
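Since no test coverage is reported, here is one way a minimal test could check the dimension selection without PyTorch. It uses a hypothetical stub, RecordingTensor, that exposes only .shape and .mean(dim=...) and records the dim argument; a copy of mean_flat is included so the block runs standalone. A fuller test would compare results against real torch tensors.

```python
import unittest


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


class RecordingTensor:
    """Hypothetical stand-in exposing only .shape and .mean(dim=...)."""

    def __init__(self, shape):
        self.shape = tuple(shape)
        self.last_dim = None

    def mean(self, dim):
        # Record which dimensions mean_flat asked to reduce.
        self.last_dim = dim
        return self


class MeanFlatDimsTest(unittest.TestCase):
    def test_reduces_every_dim_except_batch(self):
        t = RecordingTensor((8, 3, 32, 32))
        mean_flat(t)
        self.assertEqual(t.last_dim, [1, 2, 3])

    def test_two_dim_input(self):
        t = RecordingTensor((8, 10))
        mean_flat(t)
        self.assertEqual(t.last_dim, [1])
```

Saved as a test module, it can be run with python -m unittest (or picked up by pytest).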