(model_size, model_dir)
| 11 | |
| 12 | |
def download_gpt2_files(model_size, model_dir):
    """Download the GPT-2 checkpoint files for one model size.

    Each file is streamed from the public GPT-2 blob store and written to
    ``model_dir`` under its original file name, with a progress bar shown
    per file.

    Args:
        model_size: One of ``"124M"``, ``"355M"``, ``"774M"``, ``"1558M"``.
        model_dir: Directory the files are written into. It is not created
            here, so it must already exist.

    Raises:
        AssertionError: If ``model_size`` is not one of the supported sizes.
        requests.HTTPError: If the server answers with an error status
            (raised by ``raise_for_status``).
    """
    assert model_size in ["124M", "355M", "774M", "1558M"], (
        f"unsupported model_size: {model_size!r}"
    )
    url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    for filename in [
        "checkpoint",
        "encoder.json",
        "hparams.json",
        "model.ckpt.data-00000-of-00001",
        "model.ckpt.index",
        "model.ckpt.meta",
        "vocab.bpe",
    ]:
        # Using the response as a context manager releases the streamed
        # connection even when the status check or the file write fails.
        with requests.get(f"{url}/{model_size}/{filename}", stream=True) as r:
            r.raise_for_status()

            with open(os.path.join(model_dir, filename), "wb") as f:
                # The header may be absent; tqdm accepts total=None and then
                # shows a bar without a known total.
                content_length = r.headers.get("content-length")
                file_size = (
                    int(content_length) if content_length is not None else None
                )
                chunk_size = 1000
                with tqdm(
                    ncols=100,
                    desc="Fetching " + filename,
                    total=file_size,
                    unit_scale=True,
                    unit="b",
                ) as pbar:
                    # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
                        # Advance by the bytes actually received: the last
                        # chunk is usually shorter than chunk_size.
                        pbar.update(len(chunk))
| 42 | |
| 43 | |
| 44 | def load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams): |
no outgoing calls
no test coverage detected