MCPcopy Index your code
hub / github.com/hpcaitech/ColossalAI / DiffusionEngine

Class DiffusionEngine

colossalai/inference/core/diffusion_engine.py:27–200  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

25
26
27class DiffusionEngine(BaseEngine):
def __init__(
    self,
    model_or_path: DiffusionPipeline | str,
    inference_config: InferenceConfig = None,
    verbose: bool = False,
    model_policy: Policy | type[Policy] = None,
) -> None:
    """
    Set up the engine: store the config, load/wrap the model, and create the request handler.

    Args:
        model_or_path (DiffusionPipeline | str): forwarded to `get_model_type` and `init_model`.
        inference_config (InferenceConfig): source of `dtype`, `high_precision` and the
            model-shard inference config. Required in practice even though the signature
            keeps a `None` default for backward compatibility.
        verbose (bool): stored as `self.verbose`.
        model_policy (Policy | type[Policy]): forwarded to `init_model`.

    Raises:
        ValueError: if `inference_config` is None.
    """
    # The config is dereferenced right away, so fail with a clear message instead of
    # an AttributeError on `None`.
    if inference_config is None:
        raise ValueError("inference_config must be provided to DiffusionEngine, got None")

    self.inference_config = inference_config
    self.dtype = inference_config.dtype
    self.high_precision = inference_config.high_precision

    self.verbose = verbose
    self.logger = get_dist_logger(__name__)
    self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

    self.model_type = get_model_type(model_or_path=model_or_path)

    # `init_model` reads `self.dtype` and `self.logger`, so it must run after they are set.
    self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

    self.request_handler = NaiveRequestHandler()

    self.counter = count()

    self._verify_args()
52
def _verify_args(self) -> None:
    """
    Check that the engine's model has the expected wrapper type.

    Raises:
        AssertionError: if `self.model` is not a `DiffusionPipe`.
    """
    # Raise explicitly instead of using `assert`, which is removed when Python runs
    # with -O; AssertionError is kept so the failure type and message are unchanged.
    if not isinstance(self.model, DiffusionPipe):
        raise AssertionError("model must be DiffusionPipe")
55
56 def init_model(
57 self,
58 model_or_path: Union[str, nn.Module, DiffusionPipeline],
59 model_policy: Union[Policy, Type[Policy]] = None,
60 model_shard_infer_config: ModelShardInferenceConfig = None,
61 ):
62 """
63 Shard model or/and Load weight
64
65 Args:
66 model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
67 model_policy (Policy): the policy to replace the model.
68 model_inference_config: the configuration for modeling initialization when inference.
69 model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
70 """
71 if isinstance(model_or_path, str):
72 model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
73 policy_map_key = model.__class__.__name__
74 model = DiffusionPipe(model)
75 elif isinstance(model_or_path, DiffusionPipeline):
76 policy_map_key = model_or_path.__class__.__name__
77 model = DiffusionPipe(model_or_path)
78 else:
79 self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!")
80
81 torch.cuda.empty_cache()
82 init_gpu_memory = torch.cuda.mem_get_info()[0]
83
84 self.device = get_accelerator().get_current_device()

Callers 1

__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…