| 141 | return K2 |
| 142 | |
| 143 | class SOT: |
def __init__(self, taskname='ACT', root_dir='./clustertemp/', d=200,
             reg_e=0.1, reg_cl=0.1, reg_ce=0.1, rule='median'):
    """Record the configuration used by the SOT adaptation pipeline.

    Args:
        taskname: task label stored on the instance (not read by the
            methods visible alongside this constructor).
        root_dir: working directory that ``fit_predict`` forwards to
            ``gmm_target``.
        d: value that ``fit_predict`` forwards to ``gmm_target`` as its third
            positional argument; its meaning is defined by that helper.
        reg_e: entropic regularisation passed to ``ot.da.sinkhorn_lpl1_mm``
            in ``partot_DA``.
        reg_cl: class-regularisation weight passed to
            ``ot.da.sinkhorn_lpl1_mm`` right after ``reg_e``.
        reg_ce: entropic regularisation passed to
            ``entropic_partial_wasserstein`` in ``partot_DA``.
        rule: stored only; ``partot_DA`` currently always scales costs by
            their median regardless of this value.
    """
    # Run bookkeeping.
    self.taskname, self.root_dir, self.d = taskname, root_dir, d
    # Optimal-transport regularisation strengths.
    self.reg_e, self.reg_cl, self.reg_ce = reg_e, reg_cl, reg_ce
    self.rule = rule
| 153 | |
def get_target(self, filepath, modelpath, targetname):
    """Load the target-domain summary and the persisted target model.

    Args:
        filepath: path to a JSON file whose top-level object contains
            ``targetname``.
        modelpath: path to a file readable by ``joblib.load``.
        targetname: key to read from the JSON object.

    Side effects:
        Sets ``self.diag_t`` to ``record[targetname]``. ``fit_predict`` later
        reads its ``'xntmu'`` and ``'xnt1'`` entries. Sets ``self.diag_g`` to
        the loaded model. ``partot_DA`` calls ``.predict`` on it, so it is
        presumably the mixture model written by ``gmm_target`` (confirm).

    Raises:
        KeyError: if the JSON object has no ``targetname`` key.
    """
    # JSON is UTF-8 by specification (RFC 8259); don't depend on the
    # platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        record = json.load(f)
    self.diag_t = record[targetname]
    # SECURITY NOTE: joblib.load unpickles arbitrary objects and can execute
    # code. Only load model files produced by this pipeline / trusted sources.
    self.diag_g = joblib.load(modelpath)
| 160 | |
def get_source(self, filepath, sourcename):
    """Load the source-domain summary stored under ``sourcename``.

    Args:
        filepath: path to a JSON file whose top-level object contains
            ``sourcename``.
        sourcename: key to read from that object.

    Side effects:
        Sets ``self.diag_s`` to ``record[sourcename]``. ``fit_predict`` later
        reads its ``'xn1'`` and ``'yns'`` entries.

    Raises:
        KeyError: if the JSON object has no ``sourcename`` key.
    """
    # JSON is UTF-8 by specification (RFC 8259); don't depend on the
    # platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        record = json.load(f)
    self.diag_s = record[sourcename]
| 166 | |
def partot_DA(self, Sx, Sy, Tx, b, xt1, ttt=1):
    """Label target samples via partial OT followed by class-regularised OT.

    Steps:
      1. Squared-Euclidean cost between ``Sx`` and ``Tx``, scaled by its
         median.
      2. Entropic *partial* transport (``entropic_partial_wasserstein``,
         regularisation ``self.reg_ce``, transported mass ``m=1``) with
         weight 1 on every source row and uniform weights over ``Tx``.
      3. The row/column sums of that plan become the marginals of
         ``ot.da.sinkhorn_lpl1_mm`` (``self.reg_e``, ``self.reg_cl``).
      4. Each source row is mapped onto the target by the barycentric
         projection of that plan. A 1-NN classifier is fit on the mapped rows
         with labels ``Sy``.
      5. Every row of ``Tx`` is classified. For each sample of ``xt1``, the
         label of the ``Tx`` row chosen by ``self.diag_g.predict`` is
         returned.

    Args:
        Sx: source features (2-D array-like, one row per sample).
        Sy: source labels, one per row of ``Sx``.
        Tx: target features (2-D array-like). Rows are indexed by
            ``self.diag_g.predict``, so they are presumably one representative
            per model component (confirm against ``gmm_target``).
        b: ignored. It is overwritten with uniform weights over ``Tx``. Kept
            for interface compatibility.
        xt1: samples passed to ``self.diag_g.predict``.
        ttt: unused; kept for interface compatibility.

    Returns:
        ndarray with one predicted label per entry of
        ``self.diag_g.predict(xt1)``.

    Notes:
        A transport plan whose total mass is below 0.5 or not finite is
        replaced by an independent (outer-product) coupling. Source rows that
        receive no mass (their projection is not finite) are left out of the
        1-NN training set.
    """
    n_src, n_tgt = len(Sx), len(Tx)
    a1 = np.ones(n_src)
    M = cdist(Sx, Tx, metric='sqeuclidean')
    # Median scaling keeps the fixed regularisation strengths meaningful
    # regardless of feature scale. A zero median (mostly coincident points)
    # would produce 0/0 costs, so skip the scaling in that degenerate case.
    scale = np.median(M)
    if scale > 0:
        M = M / scale
    # NOTE(review): the caller-supplied ``b`` is discarded here.
    b = np.ones(n_tgt) / n_tgt
    T = entropic_partial_wasserstein(a1, b, M, self.reg_ce, m=1)
    total = np.sum(T)
    # ``not (finite and >= 0.5)`` also catches NaN/inf plans, which a plain
    # ``< 0.5`` comparison silently lets through.
    if not (np.isfinite(total) and total >= 0.5):
        T = np.outer(np.ones(n_src) / n_src, b)
    gmm = self.diag_g
    index = gmm.predict(xt1)

    # Marginals of the partial plan act as weights for the label-regularised
    # problem.
    a = T.dot(np.ones(n_tgt))
    b = (T.T).dot(np.ones(n_src))
    G = ot.da.sinkhorn_lpl1_mm(a, Sy, b, M, self.reg_e, self.reg_cl)
    total = np.sum(G)
    if not (np.isfinite(total) and total >= 0.5):
        G = np.outer(np.ones(n_src) / n_src, b)

    # Barycentric projection: row i of G.dot(Tx) divided by row i's mass.
    # Broadcasting yields the same values as ``np.diag(1 / r) @ X`` without
    # building a dense n_src x n_src matrix.
    with np.errstate(divide='ignore', invalid='ignore'):
        transp_Xs_lpl1 = (1 / G.dot(np.ones(n_tgt)))[:, None] * G.dot(Tx)
    # Rows with zero mass map to inf/NaN, which KNeighborsClassifier.fit
    # rejects. Train only on rows with a finite image.
    keep = np.isfinite(transp_Xs_lpl1).all(axis=1)
    knn_clf = KNeighborsClassifier(n_neighbors=1)
    knn_clf.fit(transp_Xs_lpl1[keep], np.asarray(Sy)[keep])
    Cls2 = knn_clf.predict(Tx)
    return Cls2[index]
| 190 | |
| 191 | def fit_predict(self, Sx, Sy, Tx, Ty,sfilepath,sourcename,tfilepath,tmodelpath,targetname): |
| 192 | gmm_source(Sx,Sy,sfilepath,sourcename) |
| 193 | self.get_source(sfilepath,sourcename) |
| 194 | gmm_target(Tx,1,self.d,self.root_dir,tfilepath,tmodelpath,targetname) |
| 195 | self.get_target(tfilepath,tmodelpath,targetname) |
| 196 | ss1 = self.diag_s |
| 197 | tt1 = self.diag_t |
| 198 | xns,yns=ss1['xn1'],ss1['yns'] |
| 199 | xntmu,xnt=tt1['xntmu'],tt1['xnt1'] |
| 200 | pred= self.partot_DA(xns,yns,xnt,xntmu,Tx,ttt=2) |