Get the cache directory for TabPFN models, as appropriate for the platform.
()
| 407 | |
| 408 | |
def _cwd_fallback_cache_dir() -> Path:
    """Return the resolved ``<cwd>/.tabpfn_models`` last-resort cache location.

    Computed on demand rather than up front: ``Path.cwd()`` raises
    ``FileNotFoundError`` when the current working directory has been deleted,
    and that must not break the platform-specific lookups that never need
    this path.
    """
    return (Path.cwd() / ".tabpfn_models").resolve()


def get_cache_dir() -> Path:  # noqa: PLR0911
    """Get the cache directory for TabPFN models, as appropriate for the platform.

    Resolution order:

    1. ``settings.tabpfn.model_cache_dir`` if it is set (coerced to ``Path``).
    2. Windows (``sys.platform == "win32"``): ``%APPDATA%/tabpfn``.
    3. macOS (``"darwin"``): ``~/Library/Caches/tabpfn``.
    4. Linux / BSD: ``$XDG_CACHE_HOME/tabpfn`` when that variable holds a
       non-empty absolute path, otherwise ``~/.cache/tabpfn``.
    5. Windows without ``APPDATA``, or any unrecognised platform:
       ``<cwd>/.tabpfn_models``, after emitting a ``UserWarning``.

    Returns:
        The cache directory path. The directory is not created here.
    """
    configured = settings.tabpfn.model_cache_dir
    if configured is not None:
        # Coerce so callers always receive a Path, even if the setting is a str.
        return Path(configured)

    platform = sys.platform
    appname = "tabpfn"

    if platform == "win32":
        # Do something similar to platformdirs, but very simplified:
        # https://github.com/tox-dev/platformdirs/blob/b769439b2a3b70769a93905944a71b3e63ef4823/src/platformdirs/windows.py#L252-L265
        # Unclear how well this works.
        appdata_path = os.environ.get("APPDATA", "")
        if appdata_path.strip() != "":
            return Path(appdata_path) / appname

        use_instead_path = _cwd_fallback_cache_dir()
        warnings.warn(
            "Could not find APPDATA environment variable to get user cache dir,"
            " but detected platform 'win32'."
            f" Defaulting to a path '{use_instead_path}'."
            " If you would prefer, please specify a directory when creating"
            " the model.",
            UserWarning,
            stacklevel=2,
        )
        return use_instead_path

    if platform == "darwin":
        return Path.home() / "Library" / "Caches" / appname

    # TODO: Not entirely sure here, Python doesn't explicitly list
    # all of these and defaults to the underlying operating system
    # if not sure. ``sys.platform`` may carry a version suffix
    # (e.g. "freebsd13"), hence the prefix match.
    linux_likes = ("freebsd", "linux", "netbsd", "openbsd")
    if platform.startswith(linux_likes):
        # "" as the default covers both "unset" and "set but empty".
        xdg_cache_home = os.environ.get("XDG_CACHE_HOME", "")
        # Per the XDG Base Directory spec, a relative value is invalid and
        # must be ignored in favour of the default location.
        if xdg_cache_home.strip() != "" and Path(xdg_cache_home).is_absolute():
            return Path(xdg_cache_home) / appname
        return Path.home() / ".cache" / appname

    use_instead_path = _cwd_fallback_cache_dir()
    warnings.warn(
        f"Unknown platform '{platform}' to get user cache dir."
        f" Defaulting to a path at the execution site '{use_instead_path}'."
        " If you would prefer, please specify a directory when creating"
        " the model.",
        UserWarning,
        stacklevel=2,
    )
    return use_instead_path
| 461 | |
| 462 | |
| 463 | def download_model( |
no test coverage detected