MCPcopy
hub / github.com/huggingface/diffusers / generate_test_file

Function generate_test_file

utils/generate_model_tests.py:432–492  ·  view source on GitHub ↗
(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str])

Source from the content-addressed store, hash-verified

430
431
432def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str:
433 model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "")
434 testers = determine_testers(model_info, include_optional, imports)
435 tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"})
436
437 lines = [
438 "# coding=utf-8",
439 "# Copyright 2025 HuggingFace Inc.",
440 "#",
441 '# Licensed under the Apache License, Version 2.0 (the "License");',
442 "# you may not use this file except in compliance with the License.",
443 "# You may obtain a copy of the License at",
444 "#",
445 "# http://www.apache.org/licenses/LICENSE-2.0",
446 "#",
447 "# Unless required by applicable law or agreed to in writing, software",
448 '# distributed under the License is distributed on an "AS IS" BASIS,',
449 "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
450 "# See the License for the specific language governing permissions and",
451 "# limitations under the License.",
452 "",
453 "import torch",
454 "",
455 f"from diffusers import {model_info['name']}",
456 "from diffusers.utils.torch_utils import randn_tensor",
457 "",
458 "from ...testing_utils import enable_full_determinism, torch_device",
459 ]
460
461 if "LoraTesterMixin" in testers:
462 lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin")
463
464 lines.extend(
465 [
466 "from ..testing_utils import (",
467 *[f" {tester}," for tester in sorted(tester_imports)],
468 ")",
469 "",
470 "",
471 "enable_full_determinism()",
472 "",
473 "",
474 ]
475 )
476
477 config_class = f"{model_name}TesterConfig"
478 lines.append(generate_config_class(model_info, model_name))
479 lines.append("")
480 lines.append("")
481
482 for tester in testers:
483 lines.append(generate_test_class(model_name, config_class, tester))
484 lines.append("")
485 lines.append("")
486
487 if "LoraTesterMixin" in testers:
488 lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin"))
489 lines.append("")

Callers 1

main · Function · 0.85

Calls 3

determine_testers · Function · 0.85
generate_config_class · Function · 0.85
generate_test_class · Function · 0.85

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…