| 53 | |
@torch.no_grad()
def test_function(model, save_path, file_name):
    """Render inpainting and outpainting samples for the current model.

    Loads ``assets/vase_hq.jpg`` and builds two complementary boolean masks over
    a 32x32 grid. For each mask it runs ``generate`` with a single ``Condition``
    built on ``model.adapter_names[2]``. Every result is written to
    ``{save_path}/{file_name}_{condition_type}_{i}.jpg``.

    Args:
        model: Training wrapper. It must expose ``training_config``,
            ``model_config``, ``adapter_names``, ``flux_pipe`` and ``device``.
        save_path: Output directory. It is created if it does not exist.
        file_name: Prefix used for the output file names.
    """
    # ``target_size`` goes to PIL's ``Image.resize``, which interprets a size as
    # (width, height). Unpack it the same way so the pipeline renders at exactly
    # the size the condition image was resized to. The previous code passed
    # height=target_size[0], width=target_size[1], which swaps them for
    # non-square sizes.
    width, height = model.training_config["dataset"]["target_size"]
    condition_type = model.training_config["condition_type"]
    prompt = "A beautiful vase on a table."

    # Generate two masks to test inpainting and outpainting.
    # mask1 is False only in the central 16x16 square; mask2 is its complement.
    # NOTE(review): the 32x32 grid is hard-coded. Presumably it must match the
    # latent token grid for the configured target_size; confirm before using a
    # different target_size.
    mask1 = torch.ones((32, 32), dtype=bool)
    mask1[8:24, 8:24] = False
    mask2 = torch.logical_not(mask1)

    # Use a context manager so the source file handle is released. ``resize``
    # returns an independent image that stays valid after the file is closed.
    with Image.open("assets/vase_hq.jpg") as source:
        image = source.resize((width, height))

    # Each Condition gets one mask, and generate() receives the complementary
    # mask: (mask1 -> mask2) first, then (mask2 -> mask1).
    test_list = []
    for condition_mask, generation_mask in ((mask1, mask2), (mask2, mask1)):
        condition = Condition(
            image,
            model.adapter_names[2],
            latent_mask=condition_mask,
            is_complement=True,
        )
        test_list.append((condition, prompt, generation_mask))

    os.makedirs(save_path, exist_ok=True)
    for i, (condition, sample_prompt, latent_mask) in enumerate(test_list):
        # Re-seed for every sample so each output is reproducible on its own.
        generator = torch.Generator(device=model.device)
        generator.manual_seed(42)

        res = generate(
            model.flux_pipe,
            prompt=sample_prompt,
            conditions=[condition],
            height=height,
            width=width,
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
            latent_mask=latent_mask,
        )
        file_path = os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
        res.images[0].save(file_path)
| 94 | |
| 95 | |
| 96 | def main(): |