Ensure that the data has a dtype with n_components fields. Input data of shape (n_components, n, m) is converted to an array of shape (n, m) with data type np.dtype(f'{data.dtype}, ' * n_components). Complex data is returned as a view with dtype np.dtype('float64, float64') or np.dtype('float32, float32').
(data, n_components)
| 833 | |
| 834 | |
| 835 | def _ensure_multivariate_data(data, n_components): |
| 836 | """ |
| 837 | Ensure that the data has dtype with n_components. |
| 838 | Input data of shape (n_components, n, m) is converted to an array of shape |
| 839 | (n, m) with data type np.dtype(f'{data.dtype}, ' * n_components) |
| 840 | Complex data is returned as a view with dtype np.dtype('float64, float64') |
| 841 | or np.dtype('float32, float32') |
| 842 | If n_components is 1 and data is not of type np.ndarray (i.e. PIL.Image), |
| 843 | the data is returned unchanged. |
| 844 | If data is None, the function returns None |
| 845 | |
| 846 | Parameters |
| 847 | ---------- |
| 848 | n_components : int |
| 849 | Number of variates in the data. |
| 850 | data : np.ndarray, PIL.Image or None |
| 851 | |
| 852 | Returns |
| 853 | ------- |
| 854 | np.ndarray, PIL.Image or None |
| 855 | """ |
| 856 | |
| 857 | if isinstance(data, np.ndarray): |
| 858 | if len(data.dtype.descr) == n_components: |
| 859 | # pass scalar data |
| 860 | # and already formatted data |
| 861 | return data |
| 862 | elif data.dtype in [np.complex64, np.complex128]: |
| 863 | if n_components != 2: |
| 864 | raise ValueError("Invalid data entry for multivariate data. " |
| 865 | "Complex numbers are incompatible with " |
| 866 | f"{n_components} variates.") |
| 867 | |
| 868 | # pass complex data |
| 869 | if data.dtype == np.complex128: |
| 870 | dt = np.dtype('float64, float64') |
| 871 | else: |
| 872 | dt = np.dtype('float32, float32') |
| 873 | |
| 874 | reconstructed = np.ma.array(np.ma.getdata(data).view(dt)) |
| 875 | if np.ma.is_masked(data): |
| 876 | for descriptor in dt.descr: |
| 877 | reconstructed[descriptor[0]][data.mask] = np.ma.masked |
| 878 | return reconstructed |
| 879 | |
| 880 | if n_components > 1 and len(data) == n_components: |
| 881 | # convert data from shape (n_components, n, m) |
| 882 | # to (n, m) with a new dtype |
| 883 | data = [np.ma.array(part, copy=False) for part in data] |
| 884 | dt = np.dtype(', '.join([f'{part.dtype}' for part in data])) |
| 885 | fields = [descriptor[0] for descriptor in dt.descr] |
| 886 | reconstructed = np.ma.empty(data[0].shape, dtype=dt) |
| 887 | for i, f in enumerate(fields): |
| 888 | if data[i].shape != reconstructed.shape: |
| 889 | raise ValueError("For multivariate data all variates must have same " |
| 890 | f"shape, not {data[0].shape} and {data[i].shape}") |
| 891 | reconstructed[f] = data[i] |
| 892 | if np.ma.is_masked(data[i]): |