| 42 | |
| 43 | |
def convert(model_path):
    """Load ``model.safetensors`` from ``model_path`` and permute selected weights.

    The checkpoint is read with ``mx.load``. Every entry is copied into the
    result under its original name; the weights named in the table below have
    their four axes reordered first, and all others pass through untouched.

    NOTE(review): the permutations look like a channels-first -> channels-last
    re-layout for convolution / transposed-convolution kernels -- confirm
    against the layout the target model expects.

    Args:
        model_path: Directory containing ``model.safetensors``. Must support
            the ``/`` operator with a string (presumably a ``pathlib.Path``).

    Returns:
        dict: Weight name -> array, with the listed weights transposed.
    """
    # Axis order applied to the patch-embed, neck and mask-embed conv weights.
    conv_axes = (0, 2, 3, 1)
    # Axis order applied to the mask-decoder upscaling conv weights.
    upscale_axes = (1, 2, 3, 0)

    # Names absent from this table are copied through unchanged.
    axes_by_name = {
        "vision_encoder.patch_embed.projection.weight": conv_axes,
        "vision_encoder.neck.conv1.weight": conv_axes,
        "vision_encoder.neck.conv2.weight": conv_axes,
        "prompt_encoder.mask_embed.conv1.weight": conv_axes,
        "prompt_encoder.mask_embed.conv2.weight": conv_axes,
        "prompt_encoder.mask_embed.conv3.weight": conv_axes,
        "mask_decoder.upscale_conv1.weight": upscale_axes,
        "mask_decoder.upscale_conv2.weight": upscale_axes,
    }

    checkpoint = mx.load(str(model_path / "model.safetensors"))

    converted = {}
    for name, tensor in checkpoint.items():
        axes = axes_by_name.get(name)
        converted[name] = tensor if axes is None else tensor.transpose(*axes)
    return converted
| 66 | |
| 67 | |
| 68 | if __name__ == "__main__": |