MCPcopy
hub / github.com/Yuanshi9815/OminiControl / test_function

Function test_function

omini/train_flux/train_spatial_alignment.py:108–167  ·  view source on GitHub ↗
(model, save_path, file_name)

Source from the content-addressed store, hash-verified

106
# Sample image and prompt shared by every condition type exercised below.
_TEST_IMAGE_PATH = "assets/vase_hq.jpg"
_TEST_PROMPT = "A beautiful vase on a table."


def _load_test_image(size):
    """Open the bundled test image and return a resized copy of it.

    The file handle is closed on exit; ``Image.resize`` returns a new image,
    so the returned object does not depend on the closed file.
    """
    with Image.open(_TEST_IMAGE_PATH) as img:
        return img.resize(size)


def _black_out_border(img):
    """Return ``img`` with everything outside its central box set to black.

    The kept box spans 1/4..3/4 of the width and 1/4..3/4 of the height.
    Width and height are handled separately so the box stays centred for
    non-square images (identical to a width-only box for square ones).
    """
    width, height = img.size
    x0, y0 = width // 4, height // 4
    mask = Image.new("L", img.size, 0)
    ImageDraw.Draw(mask).rectangle([x0, y0, x0 * 3, y0 * 3], fill=255)
    # composite() takes `img` where mask == 255 and the black image elsewhere.
    return Image.composite(img, Image.new("RGB", img.size, (0, 0, 0)), mask)


@torch.no_grad()
def test_function(model, save_path, file_name):
    """Run the sample generation used to monitor spatial-alignment training.

    Builds one (condition, prompt) pair from the bundled vase image according
    to ``model.training_config["condition_type"]`` and runs ``generate`` on it
    with a fixed seed.

    Args:
        model: training wrapper exposing ``training_config``, ``model_config``,
            ``adapter_names``, ``flux_pipe`` and ``device``.
        save_path: output directory; created if it does not already exist.
        file_name: base name for outputs (consumed by the part of the loop
            that is not visible in this view).

    Raises:
        NotImplementedError: if ``condition_type`` is not handled here.
    """
    dataset_cfg = model.training_config["dataset"]
    condition_size = dataset_cfg["condition_size"]
    target_size = dataset_cfg["target_size"]
    position_delta = dataset_cfg.get("position_delta", [0, 0])
    position_scale = dataset_cfg.get("position_scale", 1.0)

    # NOTE(review): hard-coded index 2 into adapter_names — confirm which
    # adapter this selects for the current training setup.
    adapter = model.adapter_names[2]
    condition_type = model.training_config["condition_type"]

    if condition_type in ("canny", "coloring", "deblurring", "depth"):
        image = _load_test_image(condition_size)
        # The literal 5 is forwarded unchanged; its meaning is defined by
        # convert_to_condition.
        condition_img = convert_to_condition(condition_type, image, 5)
    elif condition_type in ("depth_pred", "super_resolution"):
        # These tasks condition directly on the resized source image.
        condition_img = _load_test_image(condition_size)
    elif condition_type == "fill":
        # Previously this branch passed the not-yet-assigned name `condition`
        # to Condition(...), which raised UnboundLocalError.
        condition_img = _black_out_border(
            _load_test_image(condition_size).convert("RGB")
        )
    else:
        raise NotImplementedError(
            f"test_function: unsupported condition_type {condition_type!r}"
        )

    condition = Condition(condition_img, adapter, position_delta, position_scale)
    test_list = [(condition, _TEST_PROMPT)]

    os.makedirs(save_path, exist_ok=True)
    for i, (condition, prompt) in enumerate(test_list):
        # A fresh generator per sample so every sample starts from seed 42.
        generator = torch.Generator(device=model.device)
        generator.manual_seed(42)

        res = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=[condition],
            height=target_size[1],
            width=target_size[0],
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        # NOTE(review): this view of the source stops here (the page header
        # lists the function as lines 108-167); the remaining loop statements,
        # which presumably use `res`, `i` and `file_name`, are not shown.

Callers

nothing calls this directly

Calls 3

convert_to_condition · Function · 0.85
Condition · Class · 0.85
generate · Function · 0.85

Tested by

no test coverage detected