(
self,
config: VoxCPMConfig,
tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAE,
lora_config: LoRAConfig = None,
device: str | None = None,
)
| 110 | |
| 111 | class VoxCPMModel(nn.Module): |
| 112 | def __init__( |
| 113 | self, |
| 114 | config: VoxCPMConfig, |
| 115 | tokenizer: LlamaTokenizerFast, |
| 116 | audio_vae: AudioVAE, |
| 117 | lora_config: LoRAConfig = None, |
| 118 | device: str | None = None, |
| 119 | ): |
| 120 | super().__init__() |
| 121 | self.config = config |
| 122 | self.lora_config = lora_config |
| 123 | self.feat_dim = config.feat_dim |
| 124 | self.patch_size = config.patch_size |
| 125 | self.device = resolve_runtime_device(device, config.device) |
| 126 | self.config.device = self.device |
| 127 | resolved_dtype = pick_runtime_dtype(self.device, self.config.dtype) |
| 128 | if resolved_dtype != self.config.dtype: |
| 129 | print( |
| 130 | f"[voxcpm] adjusted dtype {self.config.dtype} -> {resolved_dtype} for device {self.device}", |
| 131 | file=sys.stderr, |
| 132 | ) |
| 133 | self.config.dtype = resolved_dtype |
| 134 | print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr) |
| 135 | |
| 136 | # Text-Semantic LM |
| 137 | self.base_lm = MiniCPMModel(config.lm_config) |
| 138 | self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype)) |
| 139 | |
| 140 | self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer) |
| 141 | self.audio_start_token = 101 |
| 142 | self.audio_end_token = 102 |
| 143 | |
| 144 | # Residual Acoustic LM |
| 145 | residual_lm_config = config.lm_config.model_copy(deep=True) |
| 146 | residual_lm_config.num_hidden_layers = config.residual_lm_num_layers |
| 147 | residual_lm_config.vocab_size = 0 |
| 148 | self.residual_lm = MiniCPMModel(residual_lm_config) |
| 149 | self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype)) |
| 150 | |
| 151 | # Local Encoder |
| 152 | encoder_config = config.lm_config.model_copy(deep=True) |
| 153 | encoder_config.hidden_size = config.encoder_config.hidden_dim |
| 154 | encoder_config.intermediate_size = config.encoder_config.ffn_dim |
| 155 | encoder_config.num_attention_heads = config.encoder_config.num_heads |
| 156 | encoder_config.num_hidden_layers = config.encoder_config.num_layers |
| 157 | encoder_config.kv_channels = config.encoder_config.kv_channels |
| 158 | encoder_config.vocab_size = 0 |
| 159 | self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim) |
| 160 | |
| 161 | # Local DiT |
| 162 | decoder_config = config.lm_config.model_copy(deep=True) |
| 163 | decoder_config.hidden_size = config.dit_config.hidden_dim |
| 164 | decoder_config.intermediate_size = config.dit_config.ffn_dim |
| 165 | decoder_config.num_attention_heads = config.dit_config.num_heads |
| 166 | decoder_config.num_hidden_layers = config.dit_config.num_layers |
| 167 | decoder_config.kv_channels = config.dit_config.kv_channels |
| 168 | decoder_config.vocab_size = 0 |
| 169 | self.feat_decoder = UnifiedCFM( |
Nothing calls `VoxCPMModel.__init__` directly.
No test coverage was detected for it.