MCPcopy
hub / github.com/huggingface/peft / from_pretrained

Method from_pretrained

tests/test_custom_models.py:1877–1929  ·  view source on GitHub ↗
(cls, model_id, dtype=None)

Source from the content-addressed store, hash-verified

1875
@classmethod
def from_pretrained(cls, model_id, dtype=None):
    """Instantiate the custom test model named ``model_id`` and cast it to ``dtype``.

    The global torch RNG is reseeded with 0 before anything else. This way, repeated
    calls with the same ``model_id`` return a model with the same initial weights.

    Args:
        model_id: Name of the test model to build (e.g. ``"MLP"``, ``"Conv2d"``,
            ``"MHA"``); see the lookup table below for every supported name.
        dtype: Target dtype passed to ``Module.to``; ``None`` means ``torch.float32``.

    Returns:
        A freshly constructed model instance, converted via ``.to(dtype)``.

    Raises:
        ValueError: If ``model_id`` matches none of the supported names.
    """
    # Reseed first so every call yields identical initial weights.
    torch.manual_seed(0)

    target_dtype = torch.float32 if dtype is None else dtype

    # Ordered (name, class) pairs. They are matched with ``==`` in sequence, like
    # the original if-chain, rather than through a dict lookup. A dict would
    # require ``model_id`` to be hashable.
    known_models = (
        ("MLP", MLP),
        ("EmbConv1D", ModelEmbConv1D),
        ("Conv1d", ModelConv1D),
        ("Conv1dBigger", ModelConv1DBigger),
        ("Conv2d", ModelConv2D),
        ("Conv2d1x1", ModelConv2D1x1),
        ("Conv1dKernel1", ModelConv1DKernel1),
        ("Conv2dGroups", ModelConv2DGroups),
        ("Conv2dGroups2", ModelConv2DGroups2),
        ("Conv3d", ModelConv3D),
        ("MLP_LayerNorm", MLP_LayerNorm),
        ("MLP2", MLP2),
        ("Conv2d2", ModelConv2D2),
        ("MHA", ModelMha),
        ("MlpUsingParameters", MlpUsingParameters),
    )

    for name, model_cls in known_models:
        if model_id == name:
            return model_cls().to(target_dtype)

    raise ValueError(f"model_id {model_id} not implemented")
1930
1931
1932class TestPeftCustomModel(PeftCommonTester):

Calls 15

ModelConv1D · Class · 0.85
ModelConv1DBigger · Class · 0.85
ModelConv2D1x1 · Class · 0.85
ModelConv1DKernel1 · Class · 0.85
ModelConv2DGroups2 · Class · 0.85
ModelConv3D · Class · 0.85
MLP_LayerNorm · Class · 0.85
MLP2 · Class · 0.85
ModelConv2D2 · Class · 0.85
MlpUsingParameters · Class · 0.85
MLP · Class · 0.70
ModelEmbConv1D · Class · 0.70

Tested by

no test coverage detected