(model_size, models_dir)
| 66 | |
| 67 | |
def load_encoder_hparams_and_params(model_size, models_dir):
    """Load the encoder, hyperparameters and weights for a GPT-2 model size.

    If ``tf.train.latest_checkpoint`` finds no checkpoint under
    ``models_dir/model_size``, that directory is created and the model files
    are fetched into it via ``download_gpt2_files`` before loading.

    Args:
        model_size: One of "124M", "355M", "774M", "1558M".
        models_dir: Parent directory holding one sub-directory per model size.

    Returns:
        Tuple ``(encoder, hparams, params)``:
        the object returned by ``get_encoder(model_size, models_dir)``,
        the dict parsed from ``<models_dir>/<model_size>/hparams.json``, and
        the result of ``load_gpt2_params_from_tf_ckpt(ckpt_path, hparams)``.

    Raises:
        AssertionError: If ``model_size`` is not one of the supported sizes.
        FileNotFoundError: If no checkpoint can be found even after downloading.
    """
    # NOTE(review): this is an assert, so it is stripped under `python -O`;
    # kept as AssertionError for backward compatibility with existing callers.
    assert model_size in ["124M", "355M", "774M", "1558M"], (
        f"unsupported model_size: {model_size!r}"
    )

    model_dir = os.path.join(models_dir, model_size)
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    if not tf_ckpt_path:  # download files if necessary
        os.makedirs(model_dir, exist_ok=True)
        download_gpt2_files(model_size, model_dir)
        tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
        # Fail loudly here rather than handing a None path to the loader below.
        if not tf_ckpt_path:
            raise FileNotFoundError(
                f"no TensorFlow checkpoint found in {model_dir!r} after download"
            )

    encoder = get_encoder(model_size, models_dir)
    # Context manager guarantees the file handle is closed (the original leaked it).
    with open(os.path.join(model_dir, "hparams.json"), encoding="utf-8") as f:
        hparams = json.load(f)
    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)

    return encoder, hparams, params
no test coverage detected