| 612 | |
| 613 | |
class EngineConfig:
    """Bundle of a pretrained config, a build config and a version string.

    Supports loading from a JSON file or string, and dumping to a plain dict
    with the keys ``'version'``, ``'pretrained_config'`` and
    ``'build_config'``.
    """

    # Build options that are not engine characteristics; they are dropped
    # from the serialized build config in ``to_dict``.
    _NON_ENGINE_BUILD_KEYS = ('dry_run', 'visualize_network')

    def __init__(self, pretrained_config: 'PretrainedConfig',
                 build_config: 'BuildConfig', version: str):
        """Store the three components as given.

        Args:
            pretrained_config: Model configuration object.
            build_config: Build settings object.
            version: Version string emitted under ``'version'`` by ``to_dict``.
        """
        self.pretrained_config = pretrained_config
        self.build_config = build_config
        self.version = version

    @classmethod
    def from_json_file(cls, config_file):
        """Load an ``EngineConfig`` from a JSON file.

        The file is decoded as UTF-8, the encoding JSON text is required to
        use, rather than the platform's locale-dependent default.

        Args:
            config_file: Path of the JSON file (anything ``open`` accepts).

        Returns:
            The instance produced by ``cls.from_json_str`` on the file text.
        """
        with open(config_file, encoding='utf-8') as f:
            return cls.from_json_str(f.read())

    @classmethod
    def from_json_str(cls, config_str):
        """Build an ``EngineConfig`` from a JSON string.

        Args:
            config_str: JSON text of an object containing the keys
                ``'pretrained_config'``, ``'build_config'`` and ``'version'``.

        Returns:
            A new instance of ``cls``.

        Raises:
            KeyError: If one of the three required keys is missing.
        """
        config = json.loads(config_str)
        return cls(PretrainedConfig.from_dict(config['pretrained_config']),
                   BuildConfig(**config['build_config']), config['version'])

    def to_dict(self):
        """Serialize this config into a JSON-compatible dict.

        The build config is dumped with ``model_dump(mode="json")``
        (presumably a pydantic model; confirm against ``BuildConfig``). Keys
        listed in ``_NON_ENGINE_BUILD_KEYS`` are removed if present.

        Returns:
            A dict with ``'version'``, ``'pretrained_config'`` and
            ``'build_config'`` entries.
        """
        build_config = self.build_config.model_dump(mode="json")
        for key in self._NON_ENGINE_BUILD_KEYS:
            build_config.pop(key, None)  # Not an Engine Characteristic
        return {
            'version': self.version,
            'pretrained_config': self.pretrained_config.to_dict(),
            'build_config': build_config,
        }
| 643 | |
| 644 | |
| 645 | class Engine: |