| 103 | |
@torch.no_grad()
def test_function(model, save_path, file_name):
    """Generate images for a fixed set of test cases and save them to disk.

    Each test case pairs a conditioning image (resized to the configured
    condition size) with a text prompt. Each case is run through
    ``generate`` with a freshly seeded generator (seed 42). The first output
    image of each run is written to
    ``<save_path>/<file_name>_<condition_type>_<i>.jpg``.

    Args:
        model: Object exposing ``training_config`` (read for
            ``["dataset"]["condition_size"]``, ``["dataset"]["target_size"]``
            and ``["condition_type"]``), ``adapter_names``, ``device``,
            ``flux_pipe`` and ``model_config``.
        save_path: Output directory; created if it does not exist.
        file_name: Prefix used for every output file name.
    """
    dataset_config = model.training_config["dataset"]
    condition_size = dataset_config["condition_size"]
    target_size = dataset_config["target_size"]

    # More details about position delta can be found in the documentation.
    # NOTE: unary minus binds tighter than `//`, so this is (-width) // 16,
    # which floors toward -inf; it equals -(width // 16) only when the width
    # is a multiple of 16.
    position_delta = [0, -condition_size[0] // 16]

    # Set adapters
    # NOTE(review): index 2 is a positional assumption about adapter_names;
    # confirm the adapter ordering against where adapter_names is built.
    adapter = model.adapter_names[2]
    condition_type = model.training_config["condition_type"]

    # (image path, prompt) for every test case, in output-index order.
    test_cases = [
        # Test case 1 (in-distribution test case)
        (
            "assets/test_in.jpg",
            "Resting on the picnic table at a lakeside campsite, it's caught in the "
            "golden glow of early morning, with mist rising from the water and tall "
            "pines casting long shadows behind the scene.",
        ),
        # Test case 2 (out-of-distribution test case)
        (
            "assets/test_out.jpg",
            "In a bright room. It is placed on a table.",
        ),
    ]

    test_list = []
    for image_path, prompt in test_cases:
        # Close the file Pillow opened as soon as we are done with it;
        # resize() returns a new, fully loaded image that stays usable.
        with Image.open(image_path) as source_image:
            image = source_image.resize(condition_size)
        condition = Condition(image, adapter, position_delta)
        test_list.append((condition, prompt))

    # Generate images
    os.makedirs(save_path, exist_ok=True)
    for i, (condition, prompt) in enumerate(test_list):
        # Reseed for every case so each one is generated from the same seed.
        generator = torch.Generator(device=model.device)
        generator.manual_seed(42)

        res = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=[condition],
            height=target_size[1],
            width=target_size[0],
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        file_path = os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
        res.images[0].save(file_path)
| 149 | |
| 150 | |
| 151 | def main(): |