()
| 176 | # Main: command-line entry point for base model evaluation |
| 177 | |
| 178 | def main(): |
| 179 | parser = argparse.ArgumentParser(description="Base model evaluation") |
| 180 | parser.add_argument('--eval', type=str, default='core,bpb,sample', help='Comma-separated evaluations to run: core,bpb,sample (default: all)') |
| 181 | parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path (e.g. openai-community/gpt2-xl)') |
| 182 | parser.add_argument('--model-tag', type=str, default=None, help='nanochat model tag to identify the checkpoint directory') |
| 183 | parser.add_argument('--step', type=int, default=None, help='Model step to load (default = last)') |
| 184 | parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per CORE task (-1 = all)') |
| 185 | parser.add_argument('--device-batch-size', type=int, default=32, help='Per-device batch size for BPB evaluation') |
| 186 | parser.add_argument('--split-tokens', type=int, default=40*524288, help='Number of tokens to evaluate per split for BPB') |
| 187 | parser.add_argument('--device-type', type=str, default='', help='cuda|cpu|mps (empty = autodetect)') |
| 188 | args = parser.parse_args() |
| 189 | |
| 190 | # Parse evaluation modes |
| 191 | eval_modes = set(mode.strip() for mode in args.eval.split(',')) |
| 192 | valid_modes = {'core', 'bpb', 'sample'} |
| 193 | invalid = eval_modes - valid_modes |
| 194 | if invalid: |
| 195 | parser.error(f"Invalid eval modes: {invalid}. Valid: {valid_modes}") |
| 196 | |
| 197 | # Distributed / precision setup |
| 198 | device_type = autodetect_device_type() if args.device_type == '' else args.device_type |
| 199 | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) |
| 200 | # Load model and tokenizer |
| 201 | is_hf_model = args.hf_path is not None |
| 202 | if is_hf_model: |
| 203 | model, tokenizer = load_hf_model(args.hf_path, device) |
| 204 | sequence_len = model.max_seq_len or 1024 |
| 205 | token_bytes = get_hf_token_bytes(tokenizer, device=device) |
| 206 | model_name = args.hf_path |
| 207 | model_slug = args.hf_path.replace("/", "-") |
| 208 | else: |
| 209 | model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step) |
| 210 | sequence_len = meta["model_config"]["sequence_len"] |
| 211 | token_bytes = get_token_bytes(device=device) |
| 212 | model_name = f"base_model (step {meta['step']})" |
| 213 | model_slug = f"base_model_{meta['step']:06d}" |
| 214 | |
| 215 | print0(f"Evaluating model: {model_name}") |
| 216 | print0(f"Eval modes: {', '.join(sorted(eval_modes))}") |
| 217 | |
| 218 | # Results to log |
| 219 | core_results = None |
| 220 | bpb_results = {} |
| 221 | samples = [] |
| 222 | unconditioned_samples = [] |
| 223 | |
| 224 | # --- Sampling --- |
| 225 | if 'sample' in eval_modes and not is_hf_model: |
| 226 | print0("\n" + "="*80) |
| 227 | print0("Model Samples") |
| 228 | print0("="*80) |
| 229 | if ddp_rank == 0: |
| 230 | prompts = [ |
| 231 | "The capital of France is", |
| 232 | "The chemical symbol of gold is", |
| 233 | "If yesterday was Friday, then tomorrow will be", |
| 234 | "The opposite of hot is", |
| 235 | "The planets of the solar system are:", |
no test coverage detected