(args)
| 264 | |
| 265 | |
| 266 | def generate(args): |
| 267 | rank = int(os.getenv("RANK", 0)) |
| 268 | world_size = int(os.getenv("WORLD_SIZE", 1)) |
| 269 | local_rank = int(os.getenv("LOCAL_RANK", 0)) |
| 270 | device = local_rank |
| 271 | _init_logging(rank) |
| 272 | |
| 273 | if args.offload_model is None: |
| 274 | args.offload_model = False if world_size > 1 else True |
| 275 | logging.info( |
| 276 | f"offload_model is not specified, set to {args.offload_model}.") |
| 277 | if world_size > 1: |
| 278 | torch.cuda.set_device(local_rank) |
| 279 | dist.init_process_group( |
| 280 | backend="nccl", |
| 281 | init_method="env://", |
| 282 | rank=rank, |
| 283 | world_size=world_size) |
| 284 | else: |
| 285 | assert not ( |
| 286 | args.t5_fsdp or args.dit_fsdp |
| 287 | ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments." |
| 288 | assert not ( |
| 289 | args.ulysses_size > 1 or args.ring_size > 1 |
| 290 | ), f"context parallel are not supported in non-distributed environments." |
| 291 | |
| 292 | if args.ulysses_size > 1 or args.ring_size > 1: |
| 293 | assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size." |
| 294 | from xfuser.core.distributed import ( |
| 295 | init_distributed_environment, |
| 296 | initialize_model_parallel, |
| 297 | ) |
| 298 | init_distributed_environment( |
| 299 | rank=dist.get_rank(), world_size=dist.get_world_size()) |
| 300 | |
| 301 | initialize_model_parallel( |
| 302 | sequence_parallel_degree=dist.get_world_size(), |
| 303 | ring_degree=args.ring_size, |
| 304 | ulysses_degree=args.ulysses_size, |
| 305 | ) |
| 306 | |
| 307 | if args.use_prompt_extend: |
| 308 | if args.prompt_extend_method == "dashscope": |
| 309 | prompt_expander = DashScopePromptExpander( |
| 310 | model_name=args.prompt_extend_model, |
| 311 | is_vl="i2v" in args.task or "flf2v" in args.task) |
| 312 | elif args.prompt_extend_method == "local_qwen": |
| 313 | prompt_expander = QwenPromptExpander( |
| 314 | model_name=args.prompt_extend_model, |
| 315 | is_vl="i2v" in args.task, |
| 316 | device=rank) |
| 317 | else: |
| 318 | raise NotImplementedError( |
| 319 | f"Unsupport prompt_extend_method: {args.prompt_extend_method}") |
| 320 | |
| 321 | cfg = WAN_CONFIGS[args.task] |
| 322 | if args.ulysses_size > 1: |
| 323 | assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`." |
no test coverage detected