Execute the full ReflACT training loop. Returns summary dict.
(self)
| 595 | self.adapter = adapter |
| 596 | |
| 597 | def train(self) -> dict: |
| 598 | """Execute the full ReflACT training loop. Returns summary dict.""" |
| 599 | cfg = self.cfg |
| 600 | adapter = self.adapter |
| 601 | out_root = cfg["out_root"] |
| 602 | os.makedirs(out_root, exist_ok=True) |
| 603 | |
| 604 | # ── Adapter setup (one-time init) ──────────────────────────── |
| 605 | adapter.setup(cfg) |
| 606 | dataloader = adapter.get_dataloader() |
| 607 | |
| 608 | def _build_train_env(batch: BatchSpec): |
| 609 | env_manager = adapter.build_env_from_batch(batch, out_root=out_root) |
| 610 | return env_manager, batch.batch_size, batch.seed |
| 611 | |
| 612 | def _build_eval_env(split: str, env_num: int, seed: int): |
| 613 | if dataloader is None: |
| 614 | env_manager = adapter.build_eval_env( |
| 615 | env_num=env_num, |
| 616 | split=split, |
| 617 | seed=seed, |
| 618 | out_root=out_root, |
| 619 | ) |
| 620 | actual_n = len(env_manager) if hasattr(env_manager, "__len__") else env_num |
| 621 | return env_manager, actual_n |
| 622 | |
| 623 | batch = dataloader.build_eval_batch( |
| 624 | env_num=env_num, |
| 625 | split=split, |
| 626 | seed=seed, |
| 627 | out_root=out_root, |
| 628 | ) |
| 629 | env_manager = adapter.build_env_from_batch(batch, out_root=out_root) |
| 630 | return env_manager, batch.batch_size |
| 631 | |
| 632 | # ── Configure models ───────────────────────────────────────────── |
| 633 | backend = cfg.get("model_backend", "azure_openai") |
| 634 | configure_azure_openai( |
| 635 | endpoint=( |
| 636 | cfg.get("azure_openai_endpoint") |
| 637 | or cfg.get("azure_endpoint") |
| 638 | or None |
| 639 | ), |
| 640 | api_version=( |
| 641 | cfg.get("azure_openai_api_version") |
| 642 | or cfg.get("azure_api_version") |
| 643 | or None |
| 644 | ), |
| 645 | api_key=( |
| 646 | cfg.get("azure_openai_api_key") |
| 647 | or cfg.get("azure_api_key") |
| 648 | or None |
| 649 | ), |
| 650 | auth_mode=cfg.get("azure_openai_auth_mode") or None, |
| 651 | ad_scope=cfg.get("azure_openai_ad_scope") or None, |
| 652 | managed_identity_client_id=cfg.get("azure_openai_managed_identity_client_id") or None, |
| 653 | optimizer_endpoint=cfg.get("optimizer_azure_openai_endpoint") or None, |
| 654 | optimizer_api_version=cfg.get("optimizer_azure_openai_api_version") or None, |
no test coverage detected