(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str])
| 430 | |
| 431 | |
| 432 | def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str: |
| 433 | model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "") |
| 434 | testers = determine_testers(model_info, include_optional, imports) |
| 435 | tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"}) |
| 436 | |
| 437 | lines = [ |
| 438 | "# coding=utf-8", |
| 439 | "# Copyright 2025 HuggingFace Inc.", |
| 440 | "#", |
| 441 | '# Licensed under the Apache License, Version 2.0 (the "License");', |
| 442 | "# you may not use this file except in compliance with the License.", |
| 443 | "# You may obtain a copy of the License at", |
| 444 | "#", |
| 445 | "# http://www.apache.org/licenses/LICENSE-2.0", |
| 446 | "#", |
| 447 | "# Unless required by applicable law or agreed to in writing, software", |
| 448 | '# distributed under the License is distributed on an "AS IS" BASIS,', |
| 449 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", |
| 450 | "# See the License for the specific language governing permissions and", |
| 451 | "# limitations under the License.", |
| 452 | "", |
| 453 | "import torch", |
| 454 | "", |
| 455 | f"from diffusers import {model_info['name']}", |
| 456 | "from diffusers.utils.torch_utils import randn_tensor", |
| 457 | "", |
| 458 | "from ...testing_utils import enable_full_determinism, torch_device", |
| 459 | ] |
| 460 | |
| 461 | if "LoraTesterMixin" in testers: |
| 462 | lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin") |
| 463 | |
| 464 | lines.extend( |
| 465 | [ |
| 466 | "from ..testing_utils import (", |
| 467 | *[f" {tester}," for tester in sorted(tester_imports)], |
| 468 | ")", |
| 469 | "", |
| 470 | "", |
| 471 | "enable_full_determinism()", |
| 472 | "", |
| 473 | "", |
| 474 | ] |
| 475 | ) |
| 476 | |
| 477 | config_class = f"{model_name}TesterConfig" |
| 478 | lines.append(generate_config_class(model_info, model_name)) |
| 479 | lines.append("") |
| 480 | lines.append("") |
| 481 | |
| 482 | for tester in testers: |
| 483 | lines.append(generate_test_class(model_name, config_class, tester)) |
| 484 | lines.append("") |
| 485 | lines.append("") |
| 486 | |
| 487 | if "LoraTesterMixin" in testers: |
| 488 | lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin")) |
| 489 | lines.append("") |
no test coverage detected
searching dependent graphs…