()
| 181 | |
| 182 | |
| 183 | def main(): |
| 184 | parser = get_parser() |
| 185 | args = parser.parse_args() |
| 186 | |
| 187 | # Main process thread setting |
| 188 | torch.set_num_threads(2) |
| 189 | |
| 190 | mp.set_start_method("spawn", force=True) |
| 191 | |
| 192 | formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" |
| 193 | logging.basicConfig(format=formatter, level=logging.INFO, force=True) |
| 194 | |
| 195 | # Prepare paths |
| 196 | sv_model_path = os.path.join( |
| 197 | args.model_dir, "speaker_similarity/wavlm_large_finetune.pth" |
| 198 | ) |
| 199 | ssl_model_path = os.path.join(args.model_dir, "speaker_similarity/wavlm_large/") |
| 200 | |
| 201 | if not os.path.exists(sv_model_path) or not os.path.exists(ssl_model_path): |
| 202 | logging.error("Model files not found. Please check --model-dir.") |
| 203 | sys.exit(1) |
| 204 | |
| 205 | logging.info(f"Calculating SIM-o for {args.wav_path}") |
| 206 | # Read list |
| 207 | samples = read_test_list(args.test_list) |
| 208 | |
| 209 | # Setup Parallel Processing |
| 210 | num_gpus = torch.cuda.device_count() |
| 211 | assert num_gpus > 0, "No GPU found. GPU is required." |
| 212 | total_procs = num_gpus * args.nj_per_gpu |
| 213 | |
| 214 | logging.info( |
| 215 | f"Starting evaluation with {total_procs} processes " f"on {num_gpus} GPUs." |
| 216 | ) |
| 217 | |
| 218 | manager = mp.Manager() |
| 219 | rank_queue = manager.Queue() |
| 220 | |
| 221 | for rank in list(range(num_gpus)) * args.nj_per_gpu: |
| 222 | rank_queue.put(rank) |
| 223 | |
| 224 | scores = [] |
| 225 | |
| 226 | fout = None |
| 227 | if args.decode_path: |
| 228 | os.makedirs(os.path.dirname(args.decode_path), exist_ok=True) |
| 229 | fout = open(args.decode_path, "w", encoding="utf8") |
| 230 | logging.info(f"Saving detailed SIM-o results to: {args.decode_path}") |
| 231 | fout.write("Prompt-path\tEval-path\tSIM-o\n") |
| 232 | |
| 233 | try: |
| 234 | with ProcessPoolExecutor( |
| 235 | max_workers=total_procs, |
| 236 | initializer=worker_init, |
| 237 | initargs=( |
| 238 | rank_queue, |
| 239 | sv_model_path, |
| 240 | ssl_model_path, |
no test coverage detected