MCPcopy
hub / github.com/Yuanshi9815/OminiControl / test_function

Function test_function

omini/train_flux/train_subject.py:105–148  ·  view source on GitHub ↗
(model, save_path, file_name)

Source from the content-addressed store, hash-verified

103
@torch.no_grad()
def test_function(model, save_path, file_name):
    """Render the fixed subject-conditioned sample images for a checkpoint.

    Two hard-coded test cases are generated with the model's FLUX pipeline.
    Each pairs a reference image from ``assets/`` with a prompt: case 0 is the
    in-distribution example and case 1 the out-of-distribution one. The
    results are written as JPEG files to ``save_path``.

    Args:
        model: Training wrapper. Attributes read here: ``training_config``
            (``["dataset"]["condition_size"]``,
            ``["dataset"]["target_size"]``, ``["condition_type"]``),
            ``adapter_names``, ``device``, ``flux_pipe`` and
            ``model_config``.
        save_path: Output directory; created if it does not already exist.
        file_name: Prefix for the output files, which are named
            ``f"{file_name}_{condition_type}_{i}.jpg"`` where ``i`` is the
            test-case index.

    Returns:
        None. The generated images are saved to disk as a side effect.
    """
    condition_size = model.training_config["dataset"]["condition_size"]
    target_size = model.training_config["dataset"]["target_size"]

    # More details about position delta can be found in the documentation.
    # NOTE: this evaluates as (-condition_size[0]) // 16, which floors; it
    # equals -(condition_size[0] // 16) only when the size is a multiple of 16.
    position_delta = [0, -condition_size[0] // 16]

    # Set adapters.
    # NOTE(review): index 2 is assumed to select the subject adapter; confirm
    # against how model.adapter_names is populated.
    adapter = model.adapter_names[2]
    condition_type = model.training_config["condition_type"]

    # (reference image path, prompt) pairs; list order defines the output
    # index i: 0 = in-distribution case, 1 = out-of-distribution case.
    test_cases = [
        (
            "assets/test_in.jpg",
            "Resting on the picnic table at a lakeside campsite, it's caught in the golden glow of early morning, with mist rising from the water and tall pines casting long shadows behind the scene.",
        ),
        (
            "assets/test_out.jpg",
            "In a bright room. It is placed on a table.",
        ),
    ]

    test_list = []
    for image_path, prompt in test_cases:
        # The context manager closes the source file deterministically;
        # resize() returns a new, fully loaded image that stays usable after
        # the file is closed.
        with Image.open(image_path) as source_image:
            image = source_image.resize(condition_size)
        test_list.append((Condition(image, adapter, position_delta), prompt))

    # Generate images
    os.makedirs(save_path, exist_ok=True)
    for i, (condition, prompt) in enumerate(test_list):
        # A fresh generator re-seeded for every case keeps each output
        # reproducible independently of the other cases.
        generator = torch.Generator(device=model.device)
        generator.manual_seed(42)

        # target_size is consumed as (width, height): [1] -> height, [0] -> width.
        res = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=[condition],
            height=target_size[1],
            width=target_size[0],
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        file_path = os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
        res.images[0].save(file_path)
149
150
151def main():

Callers

nothing calls this directly

Calls 2

Condition · class · 0.85
generate · function · 0.85

Tested by

no test coverage detected