MCPcopy
hub / github.com/Wan-Video/Wan2.2 / generate

Function generate

generate.py:315–570  ·  view source on GitHub ↗
(args)

Source retrieved from the content-addressed store (hash-verified)

313
314
315def generate(args):
316 rank = int(os.getenv("RANK", 0))
317 world_size = int(os.getenv("WORLD_SIZE", 1))
318 local_rank = int(os.getenv("LOCAL_RANK", 0))
319 device = local_rank
320 _init_logging(rank)
321
322 if args.offload_model is None:
323 args.offload_model = False if world_size > 1 else True
324 logging.info(
325 f"offload_model is not specified, set to {args.offload_model}.")
326 if world_size > 1:
327 torch.cuda.set_device(local_rank)
328 dist.init_process_group(
329 backend="nccl",
330 init_method="env://",
331 rank=rank,
332 world_size=world_size)
333 else:
334 assert not (
335 args.t5_fsdp or args.dit_fsdp
336 ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
337 assert not (
338 args.ulysses_size > 1
339 ), f"sequence parallel are not supported in non-distributed environments."
340
341 if args.ulysses_size > 1:
342 assert args.ulysses_size == world_size, f"The number of ulysses_size should be equal to the world size."
343 init_distributed_group()
344
345 if args.use_prompt_extend:
346 if args.prompt_extend_method == "dashscope":
347 prompt_expander = DashScopePromptExpander(
348 model_name=args.prompt_extend_model,
349 task=args.task,
350 is_vl=args.image is not None)
351 elif args.prompt_extend_method == "local_qwen":
352 prompt_expander = QwenPromptExpander(
353 model_name=args.prompt_extend_model,
354 task=args.task,
355 is_vl=args.image is not None,
356 device=rank)
357 else:
358 raise NotImplementedError(
359 f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
360
361 cfg = WAN_CONFIGS[args.task]
362 if args.ulysses_size > 1:
363 assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."
364
365 logging.info(f"Generation job args: {args}")
366 logging.info(f"Generation model config: {cfg}")
367
368 if dist.is_initialized():
369 base_seed = [args.base_seed] if rank == 0 else [None]
370 dist.broadcast_object_list(base_seed, src=0)
371 args.base_seed = base_seed[0]
372

Callers 1

generate.py · File · 0.85

Calls 12

generate · Method · 0.95
generate · Method · 0.95
generate · Method · 0.95
generate · Method · 0.95
generate · Method · 0.95
init_distributed_group · Function · 0.90
QwenPromptExpander · Class · 0.90
save_video · Function · 0.90
merge_video_audio · Function · 0.90
_init_logging · Function · 0.85
set_device · Method · 0.80

Tested by

no test coverage detected