(args)
| 21 | |
| 22 | |
def infer(args):
    """Run a single-prompt generation through a ColossalAI ``InferenceEngine``.

    Launches the distributed environment, loads ``MODEL_CLS`` from
    ``args.model``, wraps it in an ``InferenceEngine`` and generates one
    output for ``args.prompt``. Rank 0 saves that output to the relative path
    ``cat_parallel_size{world_size}.jpg``; the output is then handed to
    ``coordinator.print_on_master``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``model``, ``dtype``, ``max_batch_size``, ``tp_size``,
            ``use_cuda_kernel`` and ``prompt``.
    """
    # ==============================
    # Launch colossalai, setup distributed environment
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()
    # Queried once after launch; reused for the engine config and the
    # output file name so both always agree.
    world_size = dist.get_world_size()

    # ==============================
    # Load model
    # ==============================
    model_path_or_name = args.model
    # A dtype name missing from TORCH_DTYPE_MAP is passed through as None.
    torch_dtype = TORCH_DTYPE_MAP.get(args.dtype)
    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=torch_dtype)

    # ==============================
    # Initialize InferenceEngine
    # ==============================
    coordinator.print_on_master("Initializing Inference Engine...")
    inference_config = InferenceConfig(
        dtype=args.dtype,
        max_batch_size=args.max_batch_size,
        tp_size=args.tp_size,
        use_cuda_kernel=args.use_cuda_kernel,
        patched_parallelism_size=world_size,
    )
    engine = InferenceEngine(model, inference_config=inference_config, verbose=True)

    # ==============================
    # Generation
    # ==============================
    coordinator.print_on_master("Generating...")
    # A one-element prompt list is submitted, so only the first result is kept.
    out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
    # Only rank 0 writes the file, so the ranks do not all write the same path.
    if dist.get_rank() == 0:
        out.save(f"cat_parallel_size{world_size}.jpg")
    coordinator.print_on_master(out)
| 57 | |
| 58 | |
| 59 | # colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH |
no test coverage detected
searching dependent graphs…