(model, save_path, file_name)
| 106 | |
@torch.no_grad()
def test_function(model, save_path, file_name):
    """Run one fixed sample through the pipeline for the configured condition.

    Reads the condition/target sizes, the optional position delta/scale and
    the condition type from ``model.training_config``, builds a single
    ``Condition`` from the bundled vase image, makes sure ``save_path``
    exists, and calls ``generate`` with a generator seeded to 42.

    Args:
        model: Object exposing ``training_config``, ``adapter_names``,
            ``device``, ``flux_pipe`` and ``model_config``.
        save_path: Output directory; created if it does not exist.
        file_name: Not read by the code shown here -- presumably the stem of
            the output file name; confirm against the full function.

    Raises:
        NotImplementedError: If ``condition_type`` is not one of the
            supported values.
    """
    dataset_config = model.training_config["dataset"]
    condition_size = dataset_config["condition_size"]
    target_size = dataset_config["target_size"]

    position_delta = dataset_config.get("position_delta", [0, 0])
    position_scale = dataset_config.get("position_scale", 1.0)

    adapter = model.adapter_names[2]
    condition_type = model.training_config["condition_type"]
    prompt = "A beautiful vase on a table."
    test_list = []

    if condition_type in ["canny", "coloring", "deblurring", "depth"]:
        image = Image.open("assets/vase_hq.jpg")
        image = image.resize(condition_size)
        condition_img = convert_to_condition(condition_type, image, 5)
        condition = Condition(condition_img, adapter, position_delta, position_scale)
        test_list.append((condition, prompt))
    elif condition_type in ["depth_pred", "super_resolution"]:
        # Both types feed the resized image to the model unchanged.
        image = Image.open("assets/vase_hq.jpg")
        image = image.resize(condition_size)
        condition = Condition(image, adapter, position_delta, position_scale)
        test_list.append((condition, prompt))
    elif condition_type == "fill":
        condition_img = (
            Image.open("./assets/vase_hq.jpg").resize(condition_size).convert("RGB")
        )
        # Mask is 255 inside a centred square spanning 1/4..3/4 of the width
        # and 0 elsewhere, so the composite keeps the centre of the image
        # and blacks out the border.
        mask = Image.new("L", condition_img.size, 0)
        draw = ImageDraw.Draw(mask)
        a = condition_img.size[0] // 4
        b = a * 3
        draw.rectangle([a, a, b, b], fill=255)
        condition_img = Image.composite(
            condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask
        )
        # Fix: pass the masked image. The previous code passed `condition`,
        # which is not bound yet on this path and raised UnboundLocalError.
        condition = Condition(condition_img, adapter, position_delta, position_scale)
        test_list.append((condition, prompt))
    else:
        raise NotImplementedError(
            f"Unsupported condition_type: {condition_type!r}"
        )

    os.makedirs(save_path, exist_ok=True)
    for i, (condition, prompt) in enumerate(test_list):
        # Re-seed for every sample so outputs are reproducible across runs.
        generator = torch.Generator(device=model.device)
        generator.manual_seed(42)

        # NOTE(review): `res`, `i` and `file_name` are not consumed in the
        # code visible here; presumably the result is written under
        # `save_path` afterwards -- confirm against the full function.
        res = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=[condition],
            height=target_size[1],
            width=target_size[0],
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
nothing calls this directly
no test coverage detected