Args: model_name: override_config_kwargs: automodel_kwargs: Returns:
(model_name: str, override_config_kwargs=None, automodel_kwargs=None)
| 56 | |
| 57 | |
def create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:
    """Instantiate a causal-LM actor module from a HuggingFace config.

    The config comes from ``get_huggingface_actor_config`` and the module is
    built with ``AutoModelForCausalLM.from_config``. Per the transformers
    documentation, ``from_config`` builds the architecture only and does not
    load pretrained weights.

    Args:
        model_name: Model identifier forwarded to ``get_huggingface_actor_config``
            (presumably a hub name or local path — confirm against that helper).
        override_config_kwargs: Optional dict of config overrides forwarded to
            ``get_huggingface_actor_config``. ``None`` means no overrides.
        automodel_kwargs: Optional dict of keyword arguments forwarded to
            ``AutoModelForCausalLM.from_config``. Its ``trust_remote_code``
            entry (default ``False``) is also forwarded to the config loader.

    Returns:
        The ``nn.Module`` returned by ``AutoModelForCausalLM.from_config``.

    Raises:
        AssertionError: If ``override_config_kwargs`` is not a dict.
    """
    if override_config_kwargs is None:
        override_config_kwargs = {}
    if automodel_kwargs is None:
        automodel_kwargs = {}
    # Explicit check instead of `assert`, so validation still happens under
    # `python -O`. AssertionError is kept so existing callers are unaffected.
    if not isinstance(override_config_kwargs, Dict):
        raise AssertionError(f'override_config_kwargs must be a dict, got {type(override_config_kwargs)}')
    # The config loader must see the same trust_remote_code flag as the model
    # constructor, otherwise custom-code models fail to resolve their config.
    module_config = get_huggingface_actor_config(model_name,
                                                 override_config_kwargs,
                                                 trust_remote_code=automodel_kwargs.get('trust_remote_code', False))
    module: nn.Module = AutoModelForCausalLM.from_config(module_config, **automodel_kwargs)
    return module
| 79 | |
| 80 | |
| 81 | def create_huggingface_critic(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module: |
no test coverage detected