(
model: str,
)
| 37 | |
| 38 | |
def get_model_class(
    model: str,
    **config_kwargs,
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
    """Pick the ``transformers`` auto-class used to load ``model``.

    Args:
        model: Model identifier or path passed to
            ``PretrainedConfig.get_config_dict``.
        **config_kwargs: Optional extra keyword arguments forwarded
            unchanged to ``PretrainedConfig.get_config_dict``
            (e.g. ``revision``). Omitting them keeps the original behavior.

    Returns:
        ``AutoModelForImageTextToText`` when the model's top-level config
        dict contains a ``"vision_config"`` key, otherwise
        ``AutoModelForCausalLM``.
    """
    # get_config_dict returns a (config_dict, unused_kwargs) pair. Only the
    # first element describes the model, so only it is searched for a
    # vision sub-config; the unused-kwargs dict is discarded.
    config_dict, _ = PretrainedConfig.get_config_dict(model, **config_kwargs)

    if "vision_config" in config_dict:
        return AutoModelForImageTextToText
    return AutoModelForCausalLM
| 48 | |
| 49 | |
| 50 | @dataclass |
no outgoing calls
no test coverage detected