(cls, model_id, dtype=None)
| 1875 | |
@classmethod
def from_pretrained(cls, model_id, dtype=None):
    """Instantiate the model selected by ``model_id`` and convert it to ``dtype``.

    Args:
        model_id: Identifier compared against the known model names below.
        dtype: Value handed to the model's ``.to()``; ``torch.float32`` is
            substituted when it is ``None``.

    Returns:
        The result of ``.to(dtype)`` on a freshly constructed model.

    Raises:
        ValueError: If ``model_id`` matches none of the known identifiers.
    """
    # Reseed before anything is constructed so that from_pretrained always
    # returns the same model.
    torch.manual_seed(0)

    target_dtype = torch.float32 if dtype is None else dtype

    # (identifier, model class) pairs, checked in this order.
    known_models = (
        ("MLP", MLP),
        ("EmbConv1D", ModelEmbConv1D),
        ("Conv1d", ModelConv1D),
        ("Conv1dBigger", ModelConv1DBigger),
        ("Conv2d", ModelConv2D),
        ("Conv2d1x1", ModelConv2D1x1),
        ("Conv1dKernel1", ModelConv1DKernel1),
        ("Conv2dGroups", ModelConv2DGroups),
        ("Conv2dGroups2", ModelConv2DGroups2),
        ("Conv3d", ModelConv3D),
        ("MLP_LayerNorm", MLP_LayerNorm),
        ("MLP2", MLP2),
        ("Conv2d2", ModelConv2D2),
        ("MHA", ModelMha),
        ("MlpUsingParameters", MlpUsingParameters),
    )

    for known_id, model_cls in known_models:
        if model_id == known_id:
            return model_cls().to(target_dtype)

    raise ValueError(f"model_id {model_id} not implemented")
| 1930 | |
| 1931 | |
| 1932 | class TestPeftCustomModel(PeftCommonTester): |
no test coverage detected