(args)
| 313 | |
| 314 | |
| 315 | def generate(args): |
| 316 | rank = int(os.getenv("RANK", 0)) |
| 317 | world_size = int(os.getenv("WORLD_SIZE", 1)) |
| 318 | local_rank = int(os.getenv("LOCAL_RANK", 0)) |
| 319 | device = local_rank |
| 320 | _init_logging(rank) |
| 321 | |
| 322 | if args.offload_model is None: |
| 323 | args.offload_model = False if world_size > 1 else True |
| 324 | logging.info( |
| 325 | f"offload_model is not specified, set to {args.offload_model}.") |
| 326 | if world_size > 1: |
| 327 | torch.cuda.set_device(local_rank) |
| 328 | dist.init_process_group( |
| 329 | backend="nccl", |
| 330 | init_method="env://", |
| 331 | rank=rank, |
| 332 | world_size=world_size) |
| 333 | else: |
| 334 | assert not ( |
| 335 | args.t5_fsdp or args.dit_fsdp |
| 336 | ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments." |
| 337 | assert not ( |
| 338 | args.ulysses_size > 1 |
| 339 | ), f"sequence parallel are not supported in non-distributed environments." |
| 340 | |
| 341 | if args.ulysses_size > 1: |
| 342 | assert args.ulysses_size == world_size, f"The number of ulysses_size should be equal to the world size." |
| 343 | init_distributed_group() |
| 344 | |
| 345 | if args.use_prompt_extend: |
| 346 | if args.prompt_extend_method == "dashscope": |
| 347 | prompt_expander = DashScopePromptExpander( |
| 348 | model_name=args.prompt_extend_model, |
| 349 | task=args.task, |
| 350 | is_vl=args.image is not None) |
| 351 | elif args.prompt_extend_method == "local_qwen": |
| 352 | prompt_expander = QwenPromptExpander( |
| 353 | model_name=args.prompt_extend_model, |
| 354 | task=args.task, |
| 355 | is_vl=args.image is not None, |
| 356 | device=rank) |
| 357 | else: |
| 358 | raise NotImplementedError( |
| 359 | f"Unsupport prompt_extend_method: {args.prompt_extend_method}") |
| 360 | |
| 361 | cfg = WAN_CONFIGS[args.task] |
| 362 | if args.ulysses_size > 1: |
| 363 | assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`." |
| 364 | |
| 365 | logging.info(f"Generation job args: {args}") |
| 366 | logging.info(f"Generation model config: {cfg}") |
| 367 | |
| 368 | if dist.is_initialized(): |
| 369 | base_seed = [args.base_seed] if rank == 0 else [None] |
| 370 | dist.broadcast_object_list(base_seed, src=0) |
| 371 | args.base_seed = base_seed[0] |
| 372 |
no test coverage detected