(device='cuda')
| 540 | |
| 541 | |
def make_tokenizer(device='cuda'):
    """Build the WaveCodec tokenizer, load its checkpoint, and freeze it.

    The model is constructed from a fixed configuration, populated from the
    "inference_apatosaurus_95000" checkpoint, switched to eval mode, moved to
    ``device`` and has gradient tracking disabled on its parameters.

    Args:
        device: Target device, e.g. ``'cuda'``, ``'cuda:0'``, ``'cpu'`` or a
            ``torch.device``. Any CUDA device gets bfloat16 weights; other
            devices keep the dtype the model was built and loaded with.

    Returns:
        The tokenizer module in eval mode, on ``device``, with
        ``requires_grad`` set to False on its parameters.
    """
    generator_config = WaveCodec.Config(
        resnet_config=ResNetStack.Config(
            input_channels=1,
            output_channels=1,
            encode_channels=16,
            decode_channel_multiplier=4,
            kernel_size=7,
            bias=True,
            channel_ratios=(4, 8, 16, 16, 16, 16),
            strides=(2, 2, 4, 5, 5, 5),
            mode=None,
        ),
        use_weight_norm=True,

        compressor_config=GaussianZ.Config(
            dim=None,
            latent_dim=32,

            bias=True,
            use_weight_norm=True
        ),

        norm_stddev=0.05,
    )
    # NOTE(review): load_ckpt presumably validates the file against
    # expected_hash before returning the state dict -- confirm in its definition.
    checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97")

    # Calling the config object yields the model instance.
    tokenizer = generator_config()

    # strict=False: key mismatches between checkpoint and model do not raise;
    # the load result is printed so they remain visible.
    load_result = tokenizer.load_state_dict(checkpoint, strict=False)
    print_colored(f"Loaded tokenizer state dict: {load_result}", "grey")

    tokenizer = tokenizer.eval()
    # Only convert to bfloat16 if using CUDA. Compare on the string form so
    # that indexed devices ('cuda:0') and torch.device objects are recognised
    # too, not just the literal 'cuda'.
    if str(device).startswith('cuda'):
        tokenizer = tokenizer.bfloat16()
    tokenizer = tokenizer.to(device)
    # requires_grad_ is a method: it must be called. Assigning False to it
    # would leave every parameter trainable and shadow the method itself.
    tokenizer.requires_grad_(False)
    return tokenizer
| 581 |
no test coverage detected