(model_info: dict, model_name: str)
| 244 | |
| 245 | |
| 246 | def generate_config_class(model_info: dict, model_name: str) -> str: |
| 247 | class_name = f"{model_name}TesterConfig" |
| 248 | model_class = model_info["name"] |
| 249 | forward_params = model_info.get("forward_params", []) |
| 250 | init_params = model_info.get("init_params", []) |
| 251 | |
| 252 | lines = [ |
| 253 | f"class {class_name}:", |
| 254 | " @property", |
| 255 | " def model_class(self):", |
| 256 | f" return {model_class}", |
| 257 | "", |
| 258 | " @property", |
| 259 | " def pretrained_model_name_or_path(self):", |
| 260 | ' return "" # TODO: Set Hub repository ID', |
| 261 | "", |
| 262 | " @property", |
| 263 | " def pretrained_model_kwargs(self):", |
| 264 | ' return {"subfolder": "transformer"}', |
| 265 | "", |
| 266 | " @property", |
| 267 | " def generator(self):", |
| 268 | ' return torch.Generator("cpu").manual_seed(0)', |
| 269 | "", |
| 270 | " def get_init_dict(self) -> dict[str, int | list[int]]:", |
| 271 | ] |
| 272 | |
| 273 | if init_params: |
| 274 | lines.append(" # __init__ parameters:") |
| 275 | for param in init_params: |
| 276 | type_str = f": {param['type']}" if param["type"] else "" |
| 277 | default_str = f" = {param['default']}" if param["default"] is not None else "" |
| 278 | lines.append(f" # {param['name']}{type_str}{default_str}") |
| 279 | |
| 280 | lines.extend( |
| 281 | [ |
| 282 | " return {}", |
| 283 | "", |
| 284 | " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", |
| 285 | ] |
| 286 | ) |
| 287 | |
| 288 | if forward_params: |
| 289 | lines.append(" # forward() parameters:") |
| 290 | for param in forward_params: |
| 291 | type_str = f": {param['type']}" if param["type"] else "" |
| 292 | default_str = f" = {param['default']}" if param["default"] is not None else "" |
| 293 | lines.append(f" # {param['name']}{type_str}{default_str}") |
| 294 | |
| 295 | lines.extend( |
| 296 | [ |
| 297 | " # TODO: Fill in dummy inputs", |
| 298 | " return {}", |
| 299 | "", |
| 300 | " @property", |
| 301 | " def input_shape(self) -> tuple[int, ...]:", |
| 302 | " return (1, 1)", |
| 303 | "", |
no test coverage detected
searching dependent graphs…