MCPcopy
hub / github.com/huggingface/diffusers / generate_config_class

Function generate_config_class

utils/generate_model_tests.py:246–310  ·  view source on GitHub ↗
(model_info: dict, model_name: str)

Source from the content-addressed store, hash-verified

244
245
246def generate_config_class(model_info: dict, model_name: str) -> str:
247 class_name = f"{model_name}TesterConfig"
248 model_class = model_info["name"]
249 forward_params = model_info.get("forward_params", [])
250 init_params = model_info.get("init_params", [])
251
252 lines = [
253 f"class {class_name}:",
254 " @property",
255 " def model_class(self):",
256 f" return {model_class}",
257 "",
258 " @property",
259 " def pretrained_model_name_or_path(self):",
260 ' return "" # TODO: Set Hub repository ID',
261 "",
262 " @property",
263 " def pretrained_model_kwargs(self):",
264 ' return {"subfolder": "transformer"}',
265 "",
266 " @property",
267 " def generator(self):",
268 ' return torch.Generator("cpu").manual_seed(0)',
269 "",
270 " def get_init_dict(self) -> dict[str, int | list[int]]:",
271 ]
272
273 if init_params:
274 lines.append(" # __init__ parameters:")
275 for param in init_params:
276 type_str = f": {param['type']}" if param["type"] else ""
277 default_str = f" = {param['default']}" if param["default"] is not None else ""
278 lines.append(f" # {param['name']}{type_str}{default_str}")
279
280 lines.extend(
281 [
282 " return {}",
283 "",
284 " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
285 ]
286 )
287
288 if forward_params:
289 lines.append(" # forward() parameters:")
290 for param in forward_params:
291 type_str = f": {param['type']}" if param["type"] else ""
292 default_str = f" = {param['default']}" if param["default"] is not None else ""
293 lines.append(f" # {param['name']}{type_str}{default_str}")
294
295 lines.extend(
296 [
297 " # TODO: Fill in dummy inputs",
298 " return {}",
299 "",
300 " @property",
301 " def input_shape(self) -> tuple[int, ...]:",
302 " return (1, 1)",
303 "",

Callers 1

generate_test_file — Function · 0.85

Calls 1

get — Method · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…