MCPcopy
hub / github.com/Wan-Video/Wan2.1 / generate

Function generate

generate.py:266–582  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

264
265
266def generate(args):
267 rank = int(os.getenv("RANK", 0))
268 world_size = int(os.getenv("WORLD_SIZE", 1))
269 local_rank = int(os.getenv("LOCAL_RANK", 0))
270 device = local_rank
271 _init_logging(rank)
272
273 if args.offload_model is None:
274 args.offload_model = False if world_size > 1 else True
275 logging.info(
276 f"offload_model is not specified, set to {args.offload_model}.")
277 if world_size > 1:
278 torch.cuda.set_device(local_rank)
279 dist.init_process_group(
280 backend="nccl",
281 init_method="env://",
282 rank=rank,
283 world_size=world_size)
284 else:
285 assert not (
286 args.t5_fsdp or args.dit_fsdp
287 ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
288 assert not (
289 args.ulysses_size > 1 or args.ring_size > 1
290 ), f"context parallel are not supported in non-distributed environments."
291
292 if args.ulysses_size > 1 or args.ring_size > 1:
293 assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
294 from xfuser.core.distributed import (
295 init_distributed_environment,
296 initialize_model_parallel,
297 )
298 init_distributed_environment(
299 rank=dist.get_rank(), world_size=dist.get_world_size())
300
301 initialize_model_parallel(
302 sequence_parallel_degree=dist.get_world_size(),
303 ring_degree=args.ring_size,
304 ulysses_degree=args.ulysses_size,
305 )
306
307 if args.use_prompt_extend:
308 if args.prompt_extend_method == "dashscope":
309 prompt_expander = DashScopePromptExpander(
310 model_name=args.prompt_extend_model,
311 is_vl="i2v" in args.task or "flf2v" in args.task)
312 elif args.prompt_extend_method == "local_qwen":
313 prompt_expander = QwenPromptExpander(
314 model_name=args.prompt_extend_model,
315 is_vl="i2v" in args.task,
316 device=rank)
317 else:
318 raise NotImplementedError(
319 f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
320
321 cfg = WAN_CONFIGS[args.task]
322 if args.ulysses_size > 1:
323 assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."

Callers 1

generate.py · File · 0.85

Calls 12

generate · Method · 0.95
generate · Method · 0.95
generate · Method · 0.95
prepare_source · Method · 0.95
generate · Method · 0.95
QwenPromptExpander · Class · 0.90
cache_image · Function · 0.90
cache_video · Function · 0.90
_init_logging · Function · 0.85
get · Method · 0.80
forward · Method · 0.45

Tested by

no test coverage detected