| 91 | |
| 92 | |
| 93 | def main(argv: Sequence[str]) -> None: |
| 94 | if len(argv) > 1: |
| 95 | raise app.UsageError('Too many command-line arguments.') |
| 96 | |
| 97 | # Get a list of all existing loggers |
| 98 | # logging.root.manager.loggerDict contains all named loggers |
| 99 | # logging.getLogger(name) retrieves the logger object |
| 100 | loggers = [ |
| 101 | logging.getLogger(name) for name in logging.root.manager.loggerDict |
| 102 | ] |
| 103 | |
| 104 | # Iterate through the loggers and set their level to WARNING |
| 105 | for logger in loggers: |
| 106 | logger.setLevel(logging.WARNING) |
| 107 | |
| 108 | types.logger.addFilter(gepa_utils.FilterInferenceWarnings()) |
| 109 | output_dir = os.path.join( |
| 110 | _OUTPUT_DIR.value, datetime.now().strftime('%Y%m%d%H%M%S%f') |
| 111 | ) |
| 112 | os.makedirs(output_dir) |
| 113 | logging.info('Writing to output_dir=%s', output_dir) |
| 114 | config = experiment.ExperimentConfig( |
| 115 | tau_bench_env='retail', |
| 116 | agent_model='gemini-2.5-flash', |
| 117 | agent_model_provider='vertex_ai', |
| 118 | user_model='gemini-2.5-flash', |
| 119 | user_model_provider='vertex_ai', |
| 120 | max_concurrency=_MAX_CONCURRENCY.value, |
| 121 | num_eval_trials=_NUM_EVAL_TRIALS.value, |
| 122 | rnd_seed=42, |
| 123 | max_metric_calls=_MAX_METRIC_CALLS.value, |
| 124 | reflection_model='gemini-2.5-pro', |
| 125 | reflection_minibatch_size=_TRAIN_BATCH_SIZE.value, |
| 126 | use_rater=_USE_RATER.value, |
| 127 | feedback_dataset=experiment.Dataset(split='train'), |
| 128 | pareto_dataset=experiment.Dataset( |
| 129 | split='dev', max_size=_EVAL_SET_SIZE.value |
| 130 | ), |
| 131 | eval_dataset=experiment.Dataset( |
| 132 | split='test', max_size=_NUM_TEST_RECORDS.value |
| 133 | ), |
| 134 | ) |
| 135 | json.dump( |
| 136 | dataclasses.asdict(config), |
| 137 | open(os.path.join(output_dir, 'config.json'), 'w'), |
| 138 | ) |
| 139 | logging.info('Using config=%s', config) |
| 140 | |
| 141 | if _EVAL_MODE.value: |
| 142 | return experiment.run_eval( |
| 143 | output_dir=output_dir, |
| 144 | instructions=experiment.SEED_SYSTEM_INSTRUCTION, |
| 145 | config=config, |
| 146 | ) |
| 147 | |
| 148 | results = experiment.run_gepa( |
| 149 | config=config, |
| 150 | seed_instructions=experiment.SEED_SYSTEM_INSTRUCTION, |