(model_name: str, config_class: str, tester: str)
| 311 | |
| 312 | |
| 313 | def generate_test_class(model_name: str, config_class: str, tester: str) -> str: |
| 314 | tester_short = tester.replace("TesterMixin", "") |
| 315 | class_name = f"Test{model_name}{tester_short}" |
| 316 | |
| 317 | lines = [f"class {class_name}({config_class}, {tester}):"] |
| 318 | |
| 319 | if tester == "TorchCompileTesterMixin": |
| 320 | lines.extend( |
| 321 | [ |
| 322 | " @property", |
| 323 | " def different_shapes_for_compilation(self):", |
| 324 | " return [(4, 4), (4, 8), (8, 8)]", |
| 325 | "", |
| 326 | " def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:", |
| 327 | " # TODO: Implement dynamic input generation", |
| 328 | " return {}", |
| 329 | ] |
| 330 | ) |
| 331 | elif tester == "IPAdapterTesterMixin": |
| 332 | lines.extend( |
| 333 | [ |
| 334 | " @property", |
| 335 | " def ip_adapter_processor_cls(self):", |
| 336 | " return None # TODO: Set processor class", |
| 337 | "", |
| 338 | " def modify_inputs_for_ip_adapter(self, model, inputs_dict):", |
| 339 | " # TODO: Add IP adapter image embeds to inputs", |
| 340 | " return inputs_dict", |
| 341 | "", |
| 342 | " def create_ip_adapter_state_dict(self, model):", |
| 343 | " # TODO: Create IP adapter state dict", |
| 344 | " return {}", |
| 345 | ] |
| 346 | ) |
| 347 | elif tester == "SingleFileTesterMixin": |
| 348 | lines.extend( |
| 349 | [ |
| 350 | " @property", |
| 351 | " def ckpt_path(self):", |
| 352 | ' return "" # TODO: Set checkpoint path', |
| 353 | "", |
| 354 | " @property", |
| 355 | " def alternate_ckpt_paths(self):", |
| 356 | " return []", |
| 357 | "", |
| 358 | " @property", |
| 359 | " def pretrained_model_name_or_path(self):", |
| 360 | ' return "" # TODO: Set Hub repository ID', |
| 361 | ] |
| 362 | ) |
| 363 | elif tester == "GGUFTesterMixin": |
| 364 | lines.extend( |
| 365 | [ |
| 366 | " @property", |
| 367 | " def gguf_filename(self):", |
| 368 | ' return "" # TODO: Set GGUF filename', |
| 369 | "", |
| 370 | " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", |
no outgoing calls
no test coverage detected
searching dependent graphs…