(model_info: dict, include_optional: list[str], imports: set[str])
| 205 | |
| 206 | |
def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]:
    """Return the ordered, de-duplicated list of tester mixins for a model.

    Testers are collected in this order:

    1. everything in ``ALWAYS_INCLUDE_TESTERS``;
    2. the tester mapped from each base class found in ``MIXIN_TO_TESTER``;
    3. the tester mapped from each attribute in ``ATTRIBUTE_TO_TESTER`` whose
       value is present and is neither ``None`` nor ``False``;
    4. ``ContextParallelTesterMixin`` when the ``_cp_plan`` attribute is set
       (present and not ``None``);
    5. ``AttentionTesterMixin`` when ``imports`` shares any name with
       ``ATTENTION_INDICATORS``;
    6. each tester in ``OPTIONAL_TESTERS`` whose flag is in ``include_optional``.
       ``ContextParallelAttentionBackendsTesterMixin`` additionally requires
       ``"cp_attn"`` in ``include_optional`` and a set ``_cp_plan``.

    Every tester appears at most once. Its position is that of its first
    occurrence in the order above.

    Args:
        model_info: Parsed model description. Its ``"bases"`` key (iterable of
            base-class names) and ``"attributes"`` key (mapping of attribute
            name to value) are read.
        include_optional: Flags selecting which optional testers to include.
        imports: Names imported by the model's module.

    Returns:
        Tester mixin class names, in the order described above.
    """
    testers = list(ALWAYS_INCLUDE_TESTERS)
    attributes = model_info["attributes"]
    # A missing key and an explicit None both mean "no context-parallel plan".
    has_cp_plan = attributes.get("_cp_plan") is not None

    def add(tester: str) -> None:
        # Several sources (always-include, base mixins, attributes, imports)
        # can map to the same tester; keep only the first occurrence.
        if tester not in testers:
            testers.append(tester)

    for base in model_info["bases"]:
        if base in MIXIN_TO_TESTER:
            add(MIXIN_TO_TESTER[base])

    for attr, tester in ATTRIBUTE_TO_TESTER.items():
        value = attributes.get(attr)
        # Deliberately identity checks: only None/False disable the tester,
        # other falsy values (0, "", []) still enable it.
        if value is not None and value is not False:
            add(tester)

    if has_cp_plan:
        add("ContextParallelTesterMixin")

    # Include AttentionTesterMixin if the model imports attention-related classes
    if imports & ATTENTION_INDICATORS:
        add("AttentionTesterMixin")

    for tester, flag in OPTIONAL_TESTERS:
        if flag not in include_optional:
            continue
        if tester == "ContextParallelAttentionBackendsTesterMixin":
            # Only meaningful for models that declare a context-parallel plan.
            if "cp_attn" in include_optional and has_cp_plan:
                add(tester)
        else:
            add(tester)

    return testers
| 244 | |
| 245 | |
| 246 | def generate_config_class(model_info: dict, model_name: str) -> str: |
no outgoing calls
no test coverage detected
searching dependent graphs…