Repository: github.com/hpcaitech/ColossalAI

Function infer

Defined in examples/inference/stable_diffusion/sd3_generation.py, lines 23–56
Signature: infer(args)
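
infer reads six attributes from args: model, dtype, max_batch_size, tp_size, use_cuda_kernel and prompt. The parser that builds this namespace lives elsewhere in sd3_generation.py and is not part of this listing, so the sketch below is a hypothetical stand-in. Only the -m flag comes from the launch command at the end of the file (-m MODEL_PATH); the other flag spellings, types, defaults and the dtype choices are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: apart from -m, every flag name, default and choice
    # here is an assumption, chosen so that each attribute infer(args) reads exists.
    parser = argparse.ArgumentParser(description="SD3 generation with ColossalAI (sketch)")
    parser.add_argument("-m", "--model", required=True, help="local path or hub name of the model")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32", "bf16"])
    parser.add_argument("--max_batch_size", type=int, default=1)
    parser.add_argument("--tp_size", type=int, default=1)
    parser.add_argument("--use_cuda_kernel", action="store_true")
    parser.add_argument("--prompt", default="A photo of a cat")
    return parser
```

For example, build_parser().parse_args(["-m", "path/to/sd3", "--use_cuda_kernel"]) yields a namespace infer(args) can consume, with dtype defaulting to "fp16" in this sketch.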


```python
import torch.distributed as dist

import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Import paths above are assumed from ColossalAI's package layout (the file header is
# not part of this listing). MODEL_CLS and TORCH_DTYPE_MAP are module-level names
# defined elsewhere in sd3_generation.py.


def infer(args):
    # ==============================
    # Launch colossalai, setup distributed environment
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # ==============================
    # Load model and tokenizer
    # ==============================
    model_path_or_name = args.model
    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))

    # ==============================
    # Initialize InferenceEngine
    # ==============================
    coordinator.print_on_master("Initializing Inference Engine...")
    inference_config = InferenceConfig(
        dtype=args.dtype,
        max_batch_size=args.max_batch_size,
        tp_size=args.tp_size,
        use_cuda_kernel=args.use_cuda_kernel,
        patched_parallelism_size=dist.get_world_size(),
    )
    engine = InferenceEngine(model, inference_config=inference_config, verbose=True)

    # ==============================
    # Generation
    # ==============================
    coordinator.print_on_master("Generating...")
    out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
    if dist.get_rank() == 0:
        # Only rank 0 saves the image; the filename records the number of processes.
        out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
    coordinator.print_on_master(out)


# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
```
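
InferenceConfig receives patched_parallelism_size=dist.get_world_size(), so every process started by colossalai run takes part in generating a single image. Assuming a DistriFusion-style scheme in which the latent is split along its height into one contiguous patch per rank (an assumption about the engine's internals that this listing does not confirm), the rows each rank owns could be computed with the hypothetical helper patch_rows below.

```python
from typing import List, Tuple


def patch_rows(height: int, world_size: int) -> List[Tuple[int, int]]:
    """Split `height` latent rows into `world_size` contiguous [start, end) ranges.

    Hypothetical helper illustrating a height-wise patch split, one patch per rank.
    When height is not evenly divisible, the lowest ranks get one extra row each.
    """
    if world_size <= 0:
        raise ValueError("world_size must be positive")
    base, extra = divmod(height, world_size)
    ranges = []
    start = 0
    for rank in range(world_size):
        size = base + (1 if rank < extra else 0)
        ranges.append((start, start + size))
        start += size
    return ranges
```

With --nproc_per_node 1, as in the launch comment, world_size is 1 and the single rank owns every row; with 4 processes, patch_rows(128, 4) returns [(0, 32), (32, 64), (64, 96), (96, 128)].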

Callers (1): sd3_generation.py (file)

Calls:
print_on_master (method)
generate (method)
DistCoordinator (class)
InferenceConfig (class)
InferenceEngine (class)
get_rank (method)
from_pretrained (method)
get (method)
get_world_size (method)
save (method)
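
Two of these calls, print_on_master and get_rank, express the same idea in infer: status messages are printed and the image is saved only on rank 0, so N processes do not produce N copies. A dependency-free sketch of that gating follows. is_master and this print_on_master are hypothetical stand-ins for DistCoordinator, and reading the RANK environment variable (set by torchrun-style launchers) instead of calling dist.get_rank() is an assumption made so the example runs without an initialized process group.

```python
import os
from typing import Optional


def is_master(rank: Optional[int] = None) -> bool:
    """Return True on the process that should print or write files.

    Hypothetical stand-in for the rank checks in infer(): if no rank is given,
    fall back to the RANK environment variable, treating a missing value as "0".
    """
    if rank is None:
        rank = int(os.environ.get("RANK", "0"))
    return rank == 0


def print_on_master(msg: object, rank: Optional[int] = None) -> bool:
    # Print only on the master process and report whether anything was printed.
    if is_master(rank):
        print(msg)
        return True
    return False
```

Gating side effects on a single rank is a common pattern in data-parallel scripts; infer applies it explicitly for out.save and through DistCoordinator for logging.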

Tested by: no test coverage detected
