()
| 333 | |
| 334 | |
def parse_args(argv=None):
    """Build this script's command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), ``argparse`` falls back to ``sys.argv[1:]``,
            which matches the previous zero-argument behavior.

    Returns:
        argparse.Namespace with the options registered here plus whatever
        ``add_common_args`` adds to the parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data/mmlu",
        help=("Path to the data directory. If not available, "
              "download https://people.eecs.berkeley.edu/~hendrycks/data.tar"),
    )
    parser.add_argument("--ntrain", type=int, default=5)
    parser.add_argument("--max_input_length", type=int, default=2048)
    parser.add_argument("--test_trt_llm", action="store_true")
    parser.add_argument("--test_hf", action="store_true")
    parser.add_argument("--check_accuracy", action="store_true")
    # argparse applies `type` only to string defaults, so an int default would
    # leave args.accuracy_threshold as an int when the flag is omitted; a float
    # literal keeps the parsed type consistent with `type=float`.
    parser.add_argument("--accuracy_threshold", type=float, default=30.0)
    parser.add_argument("--max_ite", type=int, default=10000000)
    # Shared options come from a helper defined elsewhere in the file; its
    # return value replaces `parser`.
    parser = add_common_args(parser)

    return parser.parse_args(argv)
| 356 | |
| 357 | |
| 358 | def main(): |
no test coverage detected