(model, save_path, file_name)
| 55 | |
@torch.no_grad()
def test_function(model, save_path, file_name):
    """Render one sample image with the model's configured conditions and save it.

    For every entry in ``model.training_config["condition_type"]`` a condition
    image is derived from ``assets/vase_hq.jpg`` and wrapped in a ``Condition``.
    All conditions are then passed together to ``generate`` on
    ``model.flux_pipe`` with a fixed torch seed (42). The first output image is
    written to ``save_path``.

    Args:
        model: Object exposing ``training_config``, ``model_config``,
            ``adapter_names``, ``flux_pipe`` and ``device``.
        save_path: Output directory; created if it does not exist.
        file_name: Prefix used for the saved image's file name.

    Raises:
        NotImplementedError: If a condition type is not one of ``canny``,
            ``coloring``, ``deblurring``, ``depth`` or ``fill``.
    """
    dataset_config = model.training_config["dataset"]
    condition_size = dataset_config["condition_size"]
    target_size = dataset_config["target_size"]
    position_delta = dataset_config.get("position_delta", [0, 0])
    position_scale = dataset_config.get("position_scale", 1.0)
    condition_type = model.training_config["condition_type"]

    # Load and resize the source image once, before branching, so every
    # condition type (including "fill" when it is the first or only one) has
    # it. The context manager closes the file handle; resize() returns an
    # independent image that remains usable after the block exits.
    with Image.open("assets/vase_hq.jpg") as source:
        image = source.resize(condition_size)

    condition_list = []
    for i, c_type in enumerate(condition_type):
        if c_type in ["canny", "coloring", "deblurring", "depth"]:
            condition_img = convert_to_condition(c_type, image, 5)
        elif c_type == "fill":
            # Composite requires both inputs in the same mode, so work in RGB.
            base = image.convert("RGB")
            w, h = base.size
            x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
            y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
            mask = Image.new("L", base.size, 0)
            ImageDraw.Draw(mask).rectangle([x1, y1, x2, y2], fill=255)
            # Randomly invert the mask so either the rectangle or its
            # complement is kept; the rest is blacked out.
            # NOTE(review): this uses the global, unseeded `random`, so the
            # mask (and thus the sample) differs between calls even though the
            # torch generator below is seeded.
            if random.random() > 0.5:
                mask = Image.eval(mask, lambda a: 255 - a)
            condition_img = Image.composite(
                base, Image.new("RGB", base.size, (0, 0, 0)), mask
            )
        else:
            raise NotImplementedError
        condition_list.append(
            Condition(
                condition_img,
                # NOTE(review): the +2 offset presumably skips adapters
                # reserved for other branches — TODO confirm against where
                # adapter_names is built.
                model.adapter_names[i + 2],
                position_delta,
                position_scale,
            )
        )

    test_list = [(condition_list, "A beautiful vase on a table.")]
    os.makedirs(save_path, exist_ok=True)
    for i, (conditions, prompt) in enumerate(test_list):
        generator = torch.Generator(device=model.device)
        generator.manual_seed(42)

        res = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=conditions,
            # target_size is read as (height, width) here, while
            # condition_size above goes to PIL resize, which takes
            # (width, height) — TODO confirm both config entries match.
            height=target_size[0],
            width=target_size[1],
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        # NOTE(review): '|' is not a legal file-name character on Windows.
        file_path = os.path.join(
            save_path, f"{file_name}_{'|'.join(condition_type)}_{i}.jpg"
        )
        res.images[0].save(file_path)
nothing calls this directly
no test coverage detected