hub / github.com/jindongwang/transferlearning / partot_DA

Method partot_DA

code/traditional/sot/SOT.py:167–189
(self,Sx,Sy,Tx,b,xt1,ttt=1)

Source from the content-addressed store, hash-verified

        self.diag_s=record1[sourcename]

    def partot_DA(self,Sx,Sy,Tx,b,xt1,ttt=1):
        a1 = np.ones(len(Sx))
        M = cdist(Sx, Tx, metric='sqeuclidean')
        M = M / np.median(M)
        b=np.ones(len(Tx))/len(Tx)
        T = entropic_partial_wasserstein(a1, b, M, self.reg_ce,m=1)
        if np.sum(T)<0.5:
            a=np.ones(len(Sx))/len(Sx)
            T=np.outer(a,b)
        gmm=self.diag_g
        index=gmm.predict(xt1)

        a=T.dot(np.ones(len(Tx)))
        b=(T.T).dot(np.ones(len(Sx)))
        G=ot.da.sinkhorn_lpl1_mm(a,Sy,b,M,self.reg_e,self.reg_cl)
        if np.sum(G)<0.5:
            a=np.ones(len(Sx))/len(Sx)
            G=np.outer(a,b)
        transp_Xs_lpl1 = np.diag(1 / G.dot(np.ones(len(Tx)))) @ G.dot(Tx)
        knn_clf = KNeighborsClassifier(n_neighbors=1)
        knn_clf.fit(transp_Xs_lpl1, Sy)
        Cls2 = knn_clf.predict(Tx)
        return Cls2[index]

    def fit_predict(self, Sx, Sy, Tx, Ty,sfilepath,sourcename,tfilepath,tmodelpath,targetname):
        gmm_source(Sx,Sy,sfilepath,sourcename)
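
The last steps of partot_DA (the uniform fallback coupling, the barycentric mapping of source samples onto the target, and the 1-nearest-neighbour labelling) can be illustrated in isolation. The following is a minimal NumPy-only sketch, not the repository's code: the helper names barycentric_map, uniform_coupling and one_nn_predict are hypothetical, the coupling G is random rather than the output of an optimal-transport solver, and a plain NumPy nearest-neighbour search stands in for KNeighborsClassifier.

```python
import numpy as np


def barycentric_map(G, Tx):
    """Map each source sample to the coupling-weighted mean of the target samples.

    Equivalent to np.diag(1 / G.dot(np.ones(len(Tx)))) @ G.dot(Tx),
    but without building a dense diagonal matrix.
    """
    row_mass = G.sum(axis=1, keepdims=True)  # mass sent out by each source sample
    return (G / row_mass) @ Tx


def uniform_coupling(n_source, n_target):
    """Outer product of two uniform marginals (a fallback plan with total mass 1)."""
    a = np.ones(n_source) / n_source
    b = np.ones(n_target) / n_target
    return np.outer(a, b)


def one_nn_predict(train_X, train_y, query_X):
    """Give each query point the label of its nearest training point."""
    # Pairwise squared Euclidean distances, shape (n_query, n_train).
    d = ((query_X[:, None, :] - train_X[None, :, :]) ** 2).sum(axis=2)
    return train_y[d.argmin(axis=1)]


# Toy data: 6 labelled source points, 4 unlabelled target points (illustrative only).
rng = np.random.default_rng(0)
Sx = rng.normal(size=(6, 2))
Sy = np.array([0, 0, 0, 1, 1, 1])
Tx = rng.normal(size=(4, 2))

# Stand-in coupling; in partot_DA this comes from ot.da.sinkhorn_lpl1_mm.
G = rng.random((6, 4))
G /= G.sum()

mapped = barycentric_map(G, Tx)
reference = np.diag(1 / G.dot(np.ones(len(Tx)))) @ G.dot(Tx)
assert np.allclose(mapped, reference)

# Fit on the transported source points, predict labels for the target points.
pred = one_nn_predict(mapped, Sy, Tx)
```

Under the uniform fallback every source sample maps to the same point, the mean of Tx, so the 1-NN step then carries no geometric information; the `np.sum(...) < 0.5` guards in the method appear intended only for plans that carry too little mass.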

Callers 1

fit_predict · Method · 0.95

Calls 5

sum · Method · 0.80
dot · Method · 0.80
predict · Method · 0.45
fit · Method · 0.45

Tested by

no test coverage detected