MCPcopy
hub / github.com/z-lab/dflash / DFlashDraftModel

Class DFlashDraftModel

dflash/model_mlx.py:132–198  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

130
131
132class DFlashDraftModel(nn.Module):
def __init__(self, config: DFlashConfig):
    """Build the draft model's layers from ``config``.

    Creates the projection (``fc``) that maps the concatenation of the
    selected target-layer hidden states back to ``hidden_size``, the two
    RMS norms, the decoder layer stack and the rotary embedding.

    The embedding table, output head and embedding scale are left as
    placeholders (``None`` / ``None`` / ``1.0``) for ``bind`` to fill in.

    Note: when ``config.layer_types`` is empty or unset, it is filled in
    on the passed-in config object itself with ``"full_attention"`` for
    every hidden layer.
    """
    super().__init__()
    self.config = config

    width = config.hidden_size
    depth = config.num_hidden_layers
    norm_eps = config.rms_norm_eps

    # Default every layer to full attention when no per-layer types are given.
    if not config.layer_types:
        config.layer_types = ("full_attention",) * depth

    # Input to ``fc`` is one hidden_size-wide slice per tapped target layer.
    self.fc = nn.Linear(len(config.target_layer_ids) * width, width, bias=False)
    self.hidden_norm = nn.RMSNorm(width, eps=norm_eps)

    decoder_stack = []
    for layer_idx in range(depth):
        decoder_stack.append(DFlashDecoderLayer(config, layer_idx))
    self.layers = decoder_stack

    self.norm = nn.RMSNorm(width, eps=norm_eps)
    self.rope = _build_rope(
        config.head_dim,
        config.rope_theta,
        config.max_position_embeddings,
        config.rope_scaling,
    )

    # Populated later by ``bind``.
    self.embed_tokens = None
    self.lm_head = None
    self.embed_scale = 1.0
152
def bind(self, target_model):
    """Share the target model's token embedding and output head with this draft.

    The embedding table is searched for, in order, on:

    1. ``target_model`` itself,
    2. ``target_model.model``,
    3. ``target_model.language_model.model``.

    ``embed_scale`` is read from the embedding object, then from the object
    that owns it, defaulting to ``1.0``.

    The output head is the first of these that is present (not ``None``):
    ``target_model.lm_head``, ``target_model.language_model.lm_head``, and
    finally ``embed_tokens.as_linear`` (tied embeddings).

    Args:
        target_model: The target model (or wrapper) to borrow layers from.

    Returns:
        ``self``, so the call can be chained.

    Raises:
        AttributeError: If no ``embed_tokens`` attribute is found at any of
            the searched locations.
    """
    language_model = getattr(target_model, "language_model", None)
    candidates = (
        target_model,
        getattr(target_model, "model", None),
        getattr(language_model, "model", None),
    )
    inner = next((c for c in candidates if hasattr(c, "embed_tokens")), None)
    if inner is None:
        raise AttributeError(f"Cannot find embed_tokens in {type(target_model).__name__}")

    self.embed_tokens = inner.embed_tokens
    self.embed_scale = getattr(
        self.embed_tokens, "embed_scale", getattr(inner, "embed_scale", 1.0)
    )

    # Select the head with explicit ``is None`` checks rather than ``or``:
    # a head object may define its own truthiness (e.g. a container-like
    # module that is falsy when empty) and must not be skipped for that.
    head = getattr(target_model, "lm_head", None)
    if head is None and language_model is not None:
        head = getattr(language_model, "lm_head", None)
    if head is None:
        # No dedicated head anywhere: fall back to the tied embedding.
        head = self.embed_tokens.as_linear
    self.lm_head = head
    return self
169
def make_cache(self):
    """Return a fresh list of KV caches, one per entry in ``config.layer_types``.

    ``"sliding_attention"`` layers get a ``RotatingKVCache`` with
    ``max_size=sliding_window - 1`` and ``keep=0``; every other layer type
    gets a plain ``KVCache``.

    Raises:
        ValueError: If a sliding-attention layer is present but
            ``config.sliding_window`` is ``None``.
    """
    cfg = self.config

    def _cache_for(kind):
        # Anything that is not sliding attention uses an unbounded cache.
        if kind != "sliding_attention":
            return KVCache()
        window = cfg.sliding_window
        if window is None:
            raise ValueError("Draft config must define sliding_window for sliding_attention layers.")
        return RotatingKVCache(max_size=window - 1, keep=0)

    return [_cache_for(kind) for kind in cfg.layer_types]
180
181 def __call__(
182 self,
183 inputs,
184 target_hidden,
185 cache,
186 logits_start: int = 0,
187 ):
188 h = self.embed_tokens(inputs) * self.embed_scale
189 h_ctx = self.hidden_norm(self.fc(target_hidden))

Callers 1

load_draftFunction · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected