Sample the flux model, either interactively (set `--loop`) or for a single image. Args: height: height of the sample in pixels (should be a multiple of 16), None defaults to the size of the conditioning; width: width of the sample in pixels (should be a multiple of 16), None defaults to the size of the conditioning.
(
name: str = "flux-dev-kontext",
aspect_ratio: str | None = None,
seed: int | None = None,
prompt: str = "replace the logo with the text 'Black Forest Labs'",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
num_steps: int = 30,
loop: bool = False,
guidance: float = 2.5,
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
img_cond_path: str = "assets/cup.png",
trt: bool = False,
trt_transformer_precision: str = "bf16",
track_usage: bool = False,
)
| 148 | |
| 149 | @torch.inference_mode() |
| 150 | def main( |
| 151 | name: str = "flux-dev-kontext", |
| 152 | aspect_ratio: str | None = None, |
| 153 | seed: int | None = None, |
| 154 | prompt: str = "replace the logo with the text 'Black Forest Labs'", |
| 155 | device: str = "cuda" if torch.cuda.is_available() else "cpu", |
| 156 | num_steps: int = 30, |
| 157 | loop: bool = False, |
| 158 | guidance: float = 2.5, |
| 159 | offload: bool = False, |
| 160 | output_dir: str = "output", |
| 161 | add_sampling_metadata: bool = True, |
| 162 | img_cond_path: str = "assets/cup.png", |
| 163 | trt: bool = False, |
| 164 | trt_transformer_precision: str = "bf16", |
| 165 | track_usage: bool = False, |
| 166 | ): |
| 167 | """ |
| 168 | Sample the flux model. Either interactively (set `--loop`) or run for a |
| 169 | single image. |
| 170 | |
| 171 | Args: |
| 172 | height: height of the sample in pixels (should be a multiple of 16), None |
| 173 | defaults to the size of the conditioning |
| 174 | width: width of the sample in pixels (should be a multiple of 16), None |
| 175 | defaults to the size of the conditioning |
| 176 | seed: Set a seed for sampling |
| 177 | output_name: where to save the output image, `{idx}` will be replaced |
| 178 | by the index of the sample |
| 179 | prompt: Prompt used for sampling |
| 180 | device: Pytorch device |
| 181 | num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) |
| 182 | loop: start an interactive session and sample multiple times |
| 183 | guidance: guidance value used for guidance distillation |
| 184 | add_sampling_metadata: Add the prompt to the image Exif metadata |
| 185 | img_cond_path: path to conditioning image (jpeg/png/webp) |
| 186 | trt: use TensorRT backend for optimized inference |
| 187 | track_usage: track usage of the model for licensing purposes |
| 188 | """ |
| 189 | assert name == "flux-dev-kontext", f"Got unknown model name: {name}" |
| 190 | |
| 191 | torch_device = torch.device(device) |
| 192 | |
| 193 | output_name = os.path.join(output_dir, "img_{idx}.jpg") |
| 194 | if not os.path.exists(output_dir): |
| 195 | os.makedirs(output_dir) |
| 196 | idx = 0 |
| 197 | else: |
| 198 | fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] |
| 199 | if len(fns) > 0: |
| 200 | idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 |
| 201 | else: |
| 202 | idx = 0 |
| 203 | |
| 204 | if aspect_ratio is None: |
| 205 | width = None |
| 206 | height = None |
| 207 | else: |
nothing calls this directly
no test coverage detected
searching dependent graphs…