| 30 | # M3 Max is 400 GB/s, so 400/2.6 = ~154 tok/s |
| 31 | |
def fetch_weights() -> dict[str, Tensor]:
  """Load the three OLMoE-1B-7B-0924 safetensors shards and merge them into one dict.

  Each shard is created with Tensor.from_url and moved to Device.DEFAULT. All
  shards are obtained before any of them is parsed with nn.state.safe_load.

  Returns:
    A single name -> Tensor mapping covering every shard. If two shards
    define the same key, the later shard's tensor wins.
  """
  # TODO: make this lazy so the 3 fetches can happen in parallel
  base_url = "https://huggingface.co/allenai/OLMoE-1B-7B-0924/resolve/main"
  shard_tensors = [
    Tensor.from_url(f"{base_url}/model-{idx:05d}-of-00003.safetensors").to(Device.DEFAULT)
    for idx in (1, 2, 3)
  ]
  # Merge in shard order so later shards override earlier ones on key collisions.
  merged: dict[str, Tensor] = {}
  for shard in shard_tensors:
    merged.update(nn.state.safe_load(shard))
  return merged
| 38 | |
| 39 | if __name__ == "__main__": |
| 40 | if getenv("TORCH"): |