hub / github.com/ml-explore/mlx-examples / convert

Function convert

segment_anything/convert.py:44–65
(model_path)

Source from the content-addressed store, hash-verified

import mlx.core as mx


def convert(model_path):
    # model_path is expected to be a pathlib.Path (the "/" operator is used below).
    weight_file = str(model_path / "model.safetensors")
    weights = mx.load(weight_file)

    mlx_weights = dict()
    for k, v in weights.items():
        # Regular convolution weights: move axis 1 to the end
        # (presumably channels-first to channels-last).
        if k in {
            "vision_encoder.patch_embed.projection.weight",
            "vision_encoder.neck.conv1.weight",
            "vision_encoder.neck.conv2.weight",
            "prompt_encoder.mask_embed.conv1.weight",
            "prompt_encoder.mask_embed.conv2.weight",
            "prompt_encoder.mask_embed.conv3.weight",
        }:
            v = v.transpose(0, 2, 3, 1)
        # Upscaling convolution weights: move axis 0 to the end instead.
        if k in {
            "mask_decoder.upscale_conv1.weight",
            "mask_decoder.upscale_conv2.weight",
        }:
            v = v.transpose(1, 2, 3, 0)
        # Every other tensor is copied through unchanged.
        mlx_weights[k] = v
    return mlx_weights
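
The two axis permutations in convert can be checked on shapes alone. The sketch below is a minimal pure-Python illustration and not part of the repository: transposed_shape is a hypothetical helper, the shapes are made up, and reading the source layouts as (out, in, kH, kW) for the regular convolutions and (in, out, kH, kW) for the upscaling convolutions is an assumption based on the usual PyTorch conventions.

```python
def transposed_shape(shape, axes):
    """Return the shape that array.transpose(*axes) would produce.

    Output axis i takes its length from input axis axes[i].
    """
    if sorted(axes) != list(range(len(shape))):
        raise ValueError("axes must be a permutation of the input dimensions")
    return tuple(shape[axis] for axis in axes)


# Hypothetical regular convolution weight, assumed layout (out, in, kH, kW).
conv_src = (8, 3, 16, 16)
conv_dst = transposed_shape(conv_src, (0, 2, 3, 1))

# Hypothetical upscaling convolution weight, assumed layout (in, out, kH, kW).
upscale_src = (8, 4, 2, 2)
upscale_dst = transposed_shape(upscale_src, (1, 2, 3, 0))

print(conv_dst)     # (8, 16, 16, 3)
print(upscale_dst)  # (4, 2, 2, 8)
```

Under these assumptions both permutations end with the input-channel axis last, which would explain why the two groups of keys need different axis orders.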

Callers 1

convert.py · File · 0.70

Calls 1

items · Method · 0.80

Tested by

no test coverage detected