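Sample the flux model, either interactively (set `--loop`) or once for a single image. This demo assumes that the conditioning image and mask have the same shape and that their height and width are divisible by 32. Args: seed: set a seed for sampling; output_name: where to save the output image, with `{idx}` replaced by the index of the sample (note: the function's actual parameter is `output_dir`, and images are written inside it as `img_{idx}.jpg`); …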
(
seed: int | None = None,
prompt: str = "a white paper cup",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
num_steps: int = 50,
loop: bool = False,
guidance: float = 30.0,
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
img_cond_path: str = "assets/cup.png",
img_mask_path: str = "assets/cup_mask.png",
track_usage: bool = False,
)
| 173 | |
| 174 | @torch.inference_mode() |
| 175 | def main( |
| 176 | seed: int | None = None, |
| 177 | prompt: str = "a white paper cup", |
| 178 | device: str = "cuda" if torch.cuda.is_available() else "cpu", |
| 179 | num_steps: int = 50, |
| 180 | loop: bool = False, |
| 181 | guidance: float = 30.0, |
| 182 | offload: bool = False, |
| 183 | output_dir: str = "output", |
| 184 | add_sampling_metadata: bool = True, |
| 185 | img_cond_path: str = "assets/cup.png", |
| 186 | img_mask_path: str = "assets/cup_mask.png", |
| 187 | track_usage: bool = False, |
| 188 | ): |
| 189 | """ |
| 190 | Sample the flux model. Either interactively (set `--loop`) or run for a |
| 191 | single image. This demo assumes that the conditioning image and mask have |
| 192 | the same shape and that height and width are divisible by 32. |
| 193 | |
| 194 | Args: |
| 195 | seed: Set a seed for sampling |
| 196 | output_name: where to save the output image, `{idx}` will be replaced |
| 197 | by the index of the sample |
| 198 | prompt: Prompt used for sampling |
| 199 | device: Pytorch device |
| 200 | num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) |
| 201 | loop: start an interactive session and sample multiple times |
| 202 | guidance: guidance value used for guidance distillation |
| 203 | add_sampling_metadata: Add the prompt to the image Exif metadata |
| 204 | img_cond_path: path to conditioning image (jpeg/png/webp) |
| 205 | img_mask_path: path to conditioning mask (jpeg/png/webp) |
| 206 | track_usage: track usage of the model for licensing purposes |
| 207 | """ |
| 208 | nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) |
| 209 | |
| 210 | name = "flux-dev-fill" |
| 211 | if name not in configs: |
| 212 | available = ", ".join(configs.keys()) |
| 213 | raise ValueError(f"Got unknown model name: {name}, chose from {available}") |
| 214 | |
| 215 | torch_device = torch.device(device) |
| 216 | |
| 217 | output_name = os.path.join(output_dir, "img_{idx}.jpg") |
| 218 | if not os.path.exists(output_dir): |
| 219 | os.makedirs(output_dir) |
| 220 | idx = 0 |
| 221 | else: |
| 222 | fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] |
| 223 | if len(fns) > 0: |
| 224 | idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 |
| 225 | else: |
| 226 | idx = 0 |
| 227 | |
| 228 | # init all components |
| 229 | t5 = load_t5(torch_device, max_length=128) |
| 230 | clip = load_clip(torch_device) |
| 231 | model = load_flow_model(name, device="cpu" if offload else torch_device) |
| 232 | ae = load_ae(name, device="cpu" if offload else torch_device) |
nothing calls this directly
no test coverage detected
searching dependent graphs…