Take the mean over all non-batch dimensions.
(tensor)
| 84 | |
| 85 | |
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.

    Dimension 0 is kept; every remaining dimension is averaged away, so an
    input with two or more dimensions yields a 1-D result with one entry per
    index along dimension 0.

    :param tensor: the tensor to reduce.
    :return: the mean over dimensions 1 and up. A tensor with fewer than two
             dimensions has no non-batch dimensions and is returned
             unchanged.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    if not non_batch_dims:
        # Passing an empty ``dim`` list to ``mean`` can be interpreted by
        # PyTorch as "reduce over every dimension", which would collapse the
        # batch axis as well. There is nothing to average here, so skip the
        # reduction entirely.
        return tensor
    return tensor.mean(dim=non_batch_dims)
| 91 | |
| 92 | |
| 93 | def normalization(channels): |
no outgoing calls
no test coverage detected