(xs,ys,filepath,sourcename,slist=[],covtype='diag')
| 92 | return np.array(xmu),np.array(x1),np.array(x2),gmm |
| 93 | |
def gmm_source(xs, ys, filepath, sourcename, slist=None, covtype='diag'):
    """Fit per-class source GMMs and cache the result as JSON at ``filepath``.

    If ``filepath`` already exists nothing is done (the file acts as a cache).
    Otherwise, for every distinct label in ``ys`` the matching rows of ``xs``
    are passed to ``gmm_source_class`` together with that label's relative
    frequency. The per-class outputs are concatenated across classes and
    written as ``{sourcename: {'xmu': ..., 'xn1': ..., 'xn2': ..., 'yns': ...}}``
    using the ``MyEncoder`` JSON encoder.

    Args:
        xs: Source samples; class subsets are selected with ``xs[row_indices]``.
        ys: Numpy array of class labels, one per row of ``xs``.
        filepath: Output JSON path; an existing file is left untouched.
        sourcename: Top-level key under which the data is stored.
        slist: Forwarded unchanged to ``gmm_source_class``. ``None`` or an
            empty sequence is replaced by ``np.arange(1, len(ys) + 1)``
            (presumably candidate numbers of mixture components -- verify
            against ``gmm_source_class``).
        covtype: Covariance type forwarded to ``gmm_source_class``.

    Returns:
        None. Results are only written to ``filepath``.

    Raises:
        ValueError: If ``ys`` is empty and ``filepath`` does not exist yet.

    The stored ``yns`` holds each class label repeated once per row of that
    class's ``xn1`` output, aligned with the stacked ``xn1``/``xn2``.
    """
    if os.path.exists(filepath):
        return

    label_counts = Counter(ys)
    n_samples = len(ys)
    if n_samples == 0:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError('gmm_source: ys is empty, nothing to fit')

    if slist is None or len(slist) == 0:
        slist = np.arange(1, n_samples + 1)

    # Iterate over the labels actually present (sorted) instead of
    # range(number_of_labels): identical for labels 0..k-1, but no class is
    # silently skipped, or fitted on zero rows, when labels are not contiguous.
    for k, label in enumerate(sorted(label_counts)):
        weight = label_counts[label] / n_samples
        rows = np.where(ys == label)[0]
        xtmu, xt1, xt2, _ = gmm_source_class(xs[rows], weight, slist, covtype=covtype)
        yts = np.ones(len(xt1)) * label
        if k == 0:
            xmu, xn1, xn2, yns = xtmu, xt1, xt2, yts
        else:
            xmu = np.hstack((xmu, xtmu))
            xn1 = np.vstack((xn1, xt1))
            xn2 = np.vstack((xn2, xt2))
            yns = np.hstack((yns, yts))

    data = {'xmu': xmu, 'xn1': xn1, 'xn2': xn2, 'yns': yns}
    record = {sourcename: data}
    with open(filepath, 'w') as f:
        json.dump(record, f, cls=MyEncoder)
| 119 | |
| 120 | def entropic_partial_wasserstein(a, b, M, reg, m=1, numItermax=500, |
| 121 | stopThr=1e-100, verbose=False, log=False): |
no test coverage detected