(self,Sx,Sy,Tx,b,xt1,ttt=1)
| 165 | self.diag_s=record1[sourcename] |
| 166 | |
def partot_DA(self, Sx, Sy, Tx, b, xt1, ttt=1):
    """Predict labels for ``xt1`` using partial OT then class-regularized OT.

    Steps:
      1. Compute an entropic partial-Wasserstein plan ``T`` (mass ``m=1``)
         between the source samples ``Sx`` and the target points ``Tx``.
         Source samples get unit weight each; targets get uniform weights
         summing to 1.
      2. Use the marginals of ``T`` as the weights for a class-regularized
         (lpl1) Sinkhorn plan ``G`` that uses the source labels ``Sy``.
      3. Map each source sample to the barycenter of the ``Tx`` points it is
         sent to under ``G``.
      4. Fit a 1-NN classifier on those mapped samples with labels ``Sy``
         and predict a label for every row of ``Tx``.
      5. Return those labels indexed by ``self.diag_g.predict(xt1)``.

    If either plan is degenerate (total mass below 0.5, or non-finite), it
    is replaced by a uniform outer-product plan.

    Args:
        Sx: Source features, one row per sample (passed to ``cdist``).
        Sy: Source labels aligned with the rows of ``Sx``.
        Tx: Target points, in the same feature space as ``Sx``.
            NOTE(review): presumably these correspond one-to-one with the
            components of ``self.diag_g``, since the result is indexed by
            its predictions. Confirm against the caller.
        b: Ignored. It is overwritten with uniform target weights, as in
            the original implementation, and kept for interface
            compatibility.
        xt1: Samples passed to ``self.diag_g.predict``. The resulting
            indices select entries of the per-``Tx`` prediction.
        ttt: Unused. Kept for interface compatibility.

    Reads ``self.reg_ce``, ``self.reg_e``, ``self.reg_cl`` and
    ``self.diag_g``.

    Returns:
        Array of predicted labels, one per entry of ``self.diag_g.predict(xt1)``.
    """
    n_s, n_t = len(Sx), len(Tx)

    def _degenerate(plan):
        # ``NaN < 0.5`` is False, so a bare ``sum < 0.5`` test lets a plan
        # from a diverged Sinkhorn run slip through. Treat non-finite total
        # mass as degenerate as well.
        mass = np.sum(plan)
        return not np.isfinite(mass) or mass < 0.5

    M = cdist(Sx, Tx, metric='sqeuclidean')
    # Normalize the cost scale. Skip when the median is 0 (degenerate,
    # heavily duplicated inputs) to avoid filling M with inf/NaN.
    scale = np.median(M)
    if scale > 0:
        M = M / scale

    # The incoming ``b`` is deliberately ignored: targets are uniform.
    b_uniform = np.ones(n_t) / n_t
    T = entropic_partial_wasserstein(np.ones(n_s), b_uniform, M, self.reg_ce, m=1)
    if _degenerate(T):
        T = np.outer(np.ones(n_s) / n_s, b_uniform)

    index = self.diag_g.predict(xt1)

    # Marginals of the partial plan become the weights of the second solve.
    a_marg = T.sum(axis=1)
    b_marg = T.sum(axis=0)
    G = ot.da.sinkhorn_lpl1_mm(a_marg, Sy, b_marg, M, self.reg_e, self.reg_cl)
    if _degenerate(G):
        G = np.outer(np.ones(n_s) / n_s, b_marg)

    # Barycentric mapping: divide each row of G @ Tx by that row's mass.
    # Broadcasting replaces the dense n_s x n_s ``np.diag(...)`` product.
    # Rows that carry no mass have no image and are excluded from the
    # 1-NN training set. Dividing by 0 there would give inf/NaN rows that
    # the classifier rejects.
    row_mass = G.sum(axis=1)
    keep = row_mass > 0
    transp_Xs_lpl1 = G[keep].dot(Tx) / row_mass[keep][:, None]

    knn_clf = KNeighborsClassifier(n_neighbors=1)
    knn_clf.fit(transp_Xs_lpl1, np.asarray(Sy)[keep])
    Cls2 = knn_clf.predict(Tx)
    return Cls2[index]
| 190 | |
| 191 | def fit_predict(self, Sx, Sy, Tx, Ty,sfilepath,sourcename,tfilepath,tmodelpath,targetname): |
| 192 | gmm_source(Sx,Sy,sfilepath,sourcename) |
no test coverage detected