def download_model_from_s3(
    remote_url: str, file_list: List[str]
) -> Generator[str, None, None]:
    """
    Download the model checkpoint and tokenizer from S3 for testing.

    Downloading from S3 avoids pulling the model from the HuggingFace Hub
    during testing, which is flaky because of rate limits and HF Hub downtime.

    Args:
        remote_url: The remote URL to download the model from. Each file name
            is appended to it directly, so it should end with "/".
        file_list: The list of files to download.

    Yields:
        str: The absolute path to a temporary directory holding the downloaded
            model checkpoint and tokenizer. The directory is deleted once the
            generator is resumed or closed.
    """
    # The temporary directory exists only while the generator is suspended at `yield`.
    with tempfile.TemporaryDirectory(prefix="ray-llm-test-model") as checkpoint_dir:
        print(f"Downloading model from {remote_url} to {checkpoint_dir}", flush=True)
        for file_name in file_list:
            response = requests.get(remote_url + file_name)
            # Fail fast on HTTP errors instead of saving an error page as a model file.
            response.raise_for_status()
            with open(os.path.join(checkpoint_dir, file_name), "wb") as fp:
                fp.write(response.content)
        yield os.path.abspath(checkpoint_dir)


@pytest.fixture(scope="session")
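Because `download_model_from_s3` yields from inside a `tempfile.TemporaryDirectory` block, the returned path is only valid while the generator is suspended. The sketch below illustrates that lifetime without any network access. `fake_download` and `open_model_dir` are hypothetical stand-ins invented for this illustration, not part of the file above, and `contextlib.contextmanager` is used here only to mimic the setup-then-teardown way a yield-style fixture would presumably drive the generator.

```python
import contextlib
import os
import tempfile
from typing import Generator, List


def fake_download(file_list: List[str]) -> Generator[str, None, None]:
    """Hypothetical offline stand-in for download_model_from_s3."""
    with tempfile.TemporaryDirectory(prefix="demo-model") as checkpoint_dir:
        for file_name in file_list:
            # Write placeholder bytes instead of fetching over HTTP.
            with open(os.path.join(checkpoint_dir, file_name), "wb") as fp:
                fp.write(b"placeholder")
        yield os.path.abspath(checkpoint_dir)
    # Past this point the temporary directory has been removed.


# contextmanager runs the generator up to its yield on entry (setup)
# and resumes it on exit (teardown).
open_model_dir = contextlib.contextmanager(fake_download)

with open_model_dir(["config.json", "tokenizer.json"]) as model_dir:
    # Inside the block the files are present on disk.
    files_during = sorted(os.listdir(model_dir))

# After the block the directory and its contents are gone.
exists_after = os.path.exists(model_dir)
```

One consequence of this design is that a caller should not keep the yielded path around after the generator finishes, since the files no longer exist at that point.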