MCPcopy
hub / github.com/Yuanshi9815/OminiControl / test_function

Function test_function

omini/train_flux/train_multi_condition.py:57–114  ·  view source on GitHub ↗
(model, save_path, file_name)

Source from the content-addressed store, hash-verified

55
def _masked_fill_condition(image, rng):
    """Return ``image`` with a random rectangle kept (or blacked out).

    A rectangle whose corners are drawn from ``rng`` is painted white into an
    "L"-mode mask. With probability ~0.5 the mask is inverted. Pixels where
    the mask is 255 are taken from ``image``; all other pixels are black.

    Args:
        image: PIL image used as the fill source (converted to RGB).
        rng: ``random.Random`` instance supplying the rectangle and the
            invert decision.

    Returns:
        A new RGB PIL image of the same size as ``image``.
    """
    image = image.convert("RGB")
    w, h = image.size
    x1, x2 = sorted([rng.randint(0, w), rng.randint(0, w)])
    y1, y2 = sorted([rng.randint(0, h), rng.randint(0, h)])
    mask = Image.new("L", image.size, 0)
    ImageDraw.Draw(mask).rectangle([x1, y1, x2, y2], fill=255)
    if rng.random() > 0.5:
        # Invert: black out the rectangle instead of keeping only it.
        mask = Image.eval(mask, lambda a: 255 - a)
    return Image.composite(image, Image.new("RGB", image.size, (0, 0, 0)), mask)


@torch.no_grad()
def test_function(model, save_path, file_name):
    """Render a sample image with every configured condition type applied.

    Reads the condition/target sizes, the optional position delta/scale and
    the ``condition_type`` list from ``model.training_config``. Builds one
    ``Condition`` per entry from ``assets/vase_hq.jpg``, then runs
    ``generate`` with a fixed seed and saves the result as a JPEG under
    ``save_path``.

    Args:
        model: Object exposing ``training_config``, ``model_config``,
            ``adapter_names``, ``flux_pipe`` and ``device``.
        save_path: Output directory; created if it does not exist.
        file_name: Prefix for the saved file name.

    Raises:
        NotImplementedError: If a condition type is not one of ``canny``,
            ``coloring``, ``deblurring``, ``depth`` or ``fill``.
    """
    dataset_config = model.training_config["dataset"]
    condition_size = dataset_config["condition_size"]
    target_size = dataset_config["target_size"]
    position_delta = dataset_config.get("position_delta", [0, 0])
    position_scale = dataset_config.get("position_scale", 1.0)
    condition_type = model.training_config["condition_type"]

    seed = 42
    # Private RNG: the fill masks are then reproducible across evaluation
    # calls (matching the seeded torch generator below), and evaluation does
    # not consume the global ``random`` stream.
    rng = random.Random(seed)

    # Load the reference image once, up front. Previously only the spatial
    # branch loaded it, so a "fill" condition listed first (or alone) hit an
    # unbound ``image``. ``with`` closes the file handle once loaded.
    with Image.open("assets/vase_hq.jpg") as source:
        image = source.resize(condition_size)

    condition_list = []
    for i, c_type in enumerate(condition_type):
        if c_type in ["canny", "coloring", "deblurring", "depth"]:
            condition_img = convert_to_condition(c_type, image, 5)
        elif c_type == "fill":
            condition_img = _masked_fill_condition(image, rng)
        else:
            raise NotImplementedError
        condition_list.append(
            Condition(
                condition_img,
                # NOTE(review): the +2 offset presumably skips adapters
                # registered for other purposes; confirm against model setup.
                model.adapter_names[i + 2],
                position_delta,
                position_scale,
            )
        )
    test_list = [(condition_list, "A beautiful vase on a table.")]

    os.makedirs(save_path, exist_ok=True)
    for i, (conditions, prompt) in enumerate(test_list):
        generator = torch.Generator(device=model.device)
        generator.manual_seed(seed)

        res = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=conditions,
            height=target_size[0],
            width=target_size[1],
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        # NOTE(review): "|" is not a legal file-name character on Windows;
        # kept unchanged here because existing output paths depend on it.
        file_path = os.path.join(
            save_path, f"{file_name}_{'|'.join(condition_type)}_{i}.jpg"
        )
        res.images[0].save(file_path)

Callers

nothing calls this directly

Calls 3

convert_to_condition — Function · 0.85
Condition — Class · 0.85
generate — Function · 0.85

Tested by

no test coverage detected