github.com/hpcaitech/ColossalAI

Class ModelWrapper

Defined in colossalai/interface/model.py, lines 102–127.

A wrapper class that defines the common interface used by the booster.

Args: module (nn.Module): the model to be wrapped.
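The delegation idea behind this interface can be shown without PyTorch. The sketch below is a plain-Python stand-in, assuming a hypothetical `Module` base whose `__call__` dispatches to `forward` (as `nn.Module.__call__` does, minus hooks). `Module`, `Scale`, and `Wrapper` are illustrative names only, not ColossalAI or PyTorch APIs.

```python
from typing import Any


class Module:
    """Hypothetical stand-in for nn.Module: calling an instance runs forward()."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


class Scale(Module):
    """Hypothetical model: multiplies its input by a fixed factor."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def forward(self, x: float) -> float:
        return x * self.factor


class Wrapper(Module):
    """Mirrors ModelWrapper: keeps the model in .module and forwards every call to it."""

    def __init__(self, module: Module) -> None:
        self.module = module

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # Positional and keyword arguments pass through unchanged.
        return self.module(*args, **kwargs)


model = Scale(3.0)
wrapped = Wrapper(model)
print(wrapped(2.0))             # 6.0
print(wrapped.module is model)  # True
```

Because the wrapper adds no computation of its own, callers can treat the wrapped and unwrapped model interchangeably for the forward pass, while the original object stays reachable through `.module`.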


```python
import torch.nn as nn
from peft import PeftModel

# PeftUnwrapMixin comes from ColossalAI itself and is not shown in this excerpt.


class ModelWrapper(nn.Module):
    """
    A wrapper class to define the common interface used by booster.

    Args:
        module (nn.Module): The model to be wrapped.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def unwrap(self, unwrap_peft: bool = True):
        """
        Unwrap the model to return the original model for checkpoint saving/loading.
        """
        if isinstance(self.module, ModelWrapper):
            # Peel nested wrappers recursively (the inner call uses unwrap_peft=True).
            model = self.module.unwrap()
        else:
            model = self.module
        if unwrap_peft and isinstance(model, PeftModel):
            # Expose PEFT models through PeftUnwrapMixin for checkpoint I/O.
            model = PeftUnwrapMixin(model)
        return model

    def forward(self, *args, **kwargs):
        # Delegate the forward pass to the wrapped model.
        return self.module(*args, **kwargs)
```
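The recursive `unwrap` logic can likewise be sketched in plain Python. `FakePeftModel` and `PeftView` below are hypothetical stand-ins for `peft.PeftModel` and ColossalAI's `PeftUnwrapMixin`; only the control flow is meant to mirror `ModelWrapper.unwrap`, not the real checkpointing behavior.

```python
from typing import Any


class FakePeftModel:
    """Hypothetical stand-in for peft.PeftModel."""

    def __init__(self, base: Any) -> None:
        self.base = base


class PeftView:
    """Hypothetical stand-in for PeftUnwrapMixin: wraps a PEFT model for checkpoint I/O."""

    def __init__(self, peft_model: FakePeftModel) -> None:
        self.peft_model = peft_model


class Wrapper:
    """Mirrors ModelWrapper.unwrap without the nn.Module machinery."""

    def __init__(self, module: Any) -> None:
        self.module = module

    def unwrap(self, unwrap_peft: bool = True) -> Any:
        if isinstance(self.module, Wrapper):
            # Peel nested wrappers recursively; as in ModelWrapper.unwrap,
            # the inner call uses the default unwrap_peft=True.
            model = self.module.unwrap()
        else:
            model = self.module
        if unwrap_peft and isinstance(model, FakePeftModel):
            model = PeftView(model)
        return model


base = object()
print(Wrapper(Wrapper(base)).unwrap() is base)  # True

peft_model = FakePeftModel(base)
print(type(Wrapper(peft_model).unwrap()).__name__)            # PeftView
print(Wrapper(peft_model).unwrap(unwrap_peft=False) is peft_model)  # True
```

Because the recursive call does not pass `unwrap_peft` through, calling `unwrap(unwrap_peft=False)` on a nested wrapper around a PEFT model still returns the PEFT view in this sketch, matching the control flow of `ModelWrapper.unwrap`.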

Callers (2): init_model (method), _init_model (method)

Calls: no outgoing calls detected.

Tested by: no test coverage detected.
