| 151 | self.embed_scale = 1.0 |
| 152 | |
def bind(self, target_model):
    """Attach this object to ``target_model``'s token embedding and output head.

    The inner model is the first of these that has an ``embed_tokens``
    attribute:
    ``target_model``, ``target_model.model``, ``target_model.language_model.model``.

    Sets:
        self.embed_tokens: the inner model's ``embed_tokens``.
        self.embed_scale: ``embed_tokens.embed_scale`` if present, else the
            inner model's ``embed_scale``, else ``1.0``.
        self.lm_head: ``target_model.lm_head`` if present and not None, else
            ``target_model.language_model.lm_head`` (when that attribute
            exists and is not None), else ``self.embed_tokens.as_linear``.

    Args:
        target_model: the model to bind to.

    Returns:
        ``self``, so the call can be chained.

    Raises:
        AttributeError: if none of the candidate objects has ``embed_tokens``.
    """
    if hasattr(target_model, "embed_tokens"):
        inner = target_model
    elif hasattr(target_model, "model") and hasattr(target_model.model, "embed_tokens"):
        inner = target_model.model
    elif (hasattr(target_model, "language_model") and
          hasattr(target_model.language_model, "model") and
          hasattr(target_model.language_model.model, "embed_tokens")):
        inner = target_model.language_model.model
    else:
        raise AttributeError(f"Cannot find embed_tokens in {type(target_model).__name__}")

    self.embed_tokens = inner.embed_tokens
    self.embed_scale = getattr(self.embed_tokens, "embed_scale", getattr(inner, "embed_scale", 1.0))

    # Compare against None explicitly instead of relying on truthiness:
    # container-like modules (e.g. dict-backed or length-defining modules)
    # can be falsy, which would wrongly discard a real output head.
    lm = getattr(target_model, "language_model", target_model)
    lm_head = getattr(target_model, "lm_head", None)
    if lm_head is None:
        lm_head = getattr(lm, "lm_head", None)
    if lm_head is None:
        # No separate output projection found; reuse the embedding as a
        # linear layer (presumably tied input/output embeddings).
        lm_head = self.embed_tokens.as_linear
    self.lm_head = lm_head
    return self
| 169 | |
| 170 | def make_cache(self): |
| 171 | caches = [] |