| 606 | |
| 607 | # but4.click(train_index, [exp_dir1], info3) |
| 608 | def train_index(exp_dir1, version19): |
| 609 | # exp_dir = "%s/logs/%s" % (now_dir, exp_dir1) |
| 610 | exp_dir = "logs/%s" % (exp_dir1) |
| 611 | os.makedirs(exp_dir, exist_ok=True) |
| 612 | feature_dir = ( |
| 613 | "%s/3_feature256" % (exp_dir) |
| 614 | if version19 == "v1" |
| 615 | else "%s/3_feature768" % (exp_dir) |
| 616 | ) |
| 617 | if not os.path.exists(feature_dir): |
| 618 | return "请先进行特征提取!" |
| 619 | listdir_res = list(os.listdir(feature_dir)) |
| 620 | if len(listdir_res) == 0: |
| 621 | return "请先进行特征提取!" |
| 622 | infos = [] |
| 623 | npys = [] |
| 624 | for name in sorted(listdir_res): |
| 625 | phone = np.load("%s/%s" % (feature_dir, name)) |
| 626 | npys.append(phone) |
| 627 | big_npy = np.concatenate(npys, 0) |
| 628 | big_npy_idx = np.arange(big_npy.shape[0]) |
| 629 | np.random.shuffle(big_npy_idx) |
| 630 | big_npy = big_npy[big_npy_idx] |
| 631 | if big_npy.shape[0] > 2e5: |
| 632 | infos.append("Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0]) |
| 633 | yield "\n".join(infos) |
| 634 | try: |
| 635 | big_npy = ( |
| 636 | MiniBatchKMeans( |
| 637 | n_clusters=10000, |
| 638 | verbose=True, |
| 639 | batch_size=256 * config.n_cpu, |
| 640 | compute_labels=False, |
| 641 | init="random", |
| 642 | ) |
| 643 | .fit(big_npy) |
| 644 | .cluster_centers_ |
| 645 | ) |
| 646 | except: |
| 647 | info = traceback.format_exc() |
| 648 | logger.info(info) |
| 649 | infos.append(info) |
| 650 | yield "\n".join(infos) |
| 651 | |
| 652 | np.save("%s/total_fea.npy" % exp_dir, big_npy) |
| 653 | n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39) |
| 654 | infos.append("%s,%s" % (big_npy.shape, n_ivf)) |
| 655 | yield "\n".join(infos) |
| 656 | index = faiss.index_factory(256 if version19 == "v1" else 768, "IVF%s,Flat" % n_ivf) |
| 657 | # index = faiss.index_factory(256if version19=="v1"else 768, "IVF%s,PQ128x4fs,RFlat"%n_ivf) |
| 658 | infos.append("training") |
| 659 | yield "\n".join(infos) |
| 660 | index_ivf = faiss.extract_index_ivf(index) # |
| 661 | index_ivf.nprobe = 1 |
| 662 | index.train(big_npy) |
| 663 | faiss.write_index( |
| 664 | index, |
| 665 | "%s/trained_IVF%s_Flat_nprobe_%s_%s_%s.index" |