(
self,
config: VoxCPMConfig,
tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAEV2,
lora_config: LoRAConfig = None,
device: str | None = None,
)
| 152 | |
| 153 | class VoxCPM2Model(nn.Module): |
| 154 | def __init__( |
| 155 | self, |
| 156 | config: VoxCPMConfig, |
| 157 | tokenizer: LlamaTokenizerFast, |
| 158 | audio_vae: AudioVAEV2, |
| 159 | lora_config: LoRAConfig = None, |
| 160 | device: str | None = None, |
| 161 | ): |
| 162 | super().__init__() |
| 163 | self.config = config |
| 164 | self.lora_config = lora_config |
| 165 | self.feat_dim = config.feat_dim |
| 166 | self.patch_size = config.patch_size |
| 167 | self.device = resolve_runtime_device(device, config.device) |
| 168 | self.config.device = self.device |
| 169 | resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype) |
| 170 | if resolved_dtype != self.config.dtype: |
| 171 | print( |
| 172 | f"[voxcpm2] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}", |
| 173 | file=sys.stderr, |
| 174 | ) |
| 175 | self.config.dtype = resolved_dtype |
| 176 | print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr) |
| 177 | |
| 178 | # Text-Semantic LM |
| 179 | self.base_lm = MiniCPMModel(config.lm_config) |
| 180 | self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype)) |
| 181 | |
| 182 | self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer) |
| 183 | self.audio_start_token = 101 |
| 184 | self.audio_end_token = 102 |
| 185 | self.ref_audio_start_token = 103 |
| 186 | self.ref_audio_end_token = 104 |
| 187 | |
| 188 | # Residual Acoustic LM |
| 189 | residual_lm_config = config.lm_config.model_copy(deep=True) |
| 190 | residual_lm_config.num_hidden_layers = config.residual_lm_num_layers |
| 191 | residual_lm_config.vocab_size = 0 |
| 192 | residual_lm_config.no_rope = config.residual_lm_no_rope |
| 193 | self.residual_lm = MiniCPMModel(residual_lm_config) |
| 194 | self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype)) |
| 195 | |
| 196 | # Local Encoder |
| 197 | encoder_config = config.lm_config.model_copy(deep=True) |
| 198 | encoder_config.hidden_size = config.encoder_config.hidden_dim |
| 199 | encoder_config.intermediate_size = config.encoder_config.ffn_dim |
| 200 | encoder_config.num_attention_heads = config.encoder_config.num_heads |
| 201 | encoder_config.num_hidden_layers = config.encoder_config.num_layers |
| 202 | encoder_config.kv_channels = config.encoder_config.kv_channels |
| 203 | encoder_config.vocab_size = 0 |
| 204 | self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim) |
| 205 | |
| 206 | # Local DiT |
| 207 | decoder_config = config.lm_config.model_copy(deep=True) |
| 208 | decoder_config.hidden_size = config.dit_config.hidden_dim |
| 209 | decoder_config.intermediate_size = config.dit_config.ffn_dim |
| 210 | decoder_config.num_attention_heads = config.dit_config.num_heads |
| 211 | decoder_config.num_hidden_layers = config.dit_config.num_layers |
nothing calls this directly
no test coverage detected