MCPcopy
hub / github.com/jindongwang/transferlearning / SOT

Class SOT

code/traditional/sot/SOT.py:143–202  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

141 return K2
142
143class SOT:
144 def __init__(self,taskname='ACT',root_dir='./clustertemp/',d=200,
145 reg_e=0.1, reg_cl=0.1, reg_ce=0.1,rule='median'):
146 self.taskname=taskname
147 self.root_dir=root_dir
148 self.d=d
149 self.reg_e=reg_e
150 self.reg_cl=reg_cl
151 self.reg_ce=reg_ce
152 self.rule=rule
153
154 def get_target(self,filepath,modelpath,targetname):
155 with open(filepath, 'r') as f:
156 s = f.read()
157 record = json.loads(s)
158 self.diag_t = record[targetname]
159 self.diag_g = joblib.load(modelpath)
160
161 def get_source(self,filepath,sourcename):
162 with open(filepath,'r') as f:
163 s=f.read()
164 record1=json.loads(s)
165 self.diag_s=record1[sourcename]
166
167 def partot_DA(self,Sx,Sy,Tx,b,xt1,ttt=1):
168 a1 = np.ones(len(Sx))
169 M = cdist(Sx, Tx, metric='sqeuclidean')
170 M = M / np.median(M)
171 b=np.ones(len(Tx))/len(Tx)
172 T = entropic_partial_wasserstein(a1, b, M, self.reg_ce,m=1)
173 if np.sum(T)<0.5:
174 a=np.ones(len(Sx))/len(Sx)
175 T=np.outer(a,b)
176 gmm=self.diag_g
177 index=gmm.predict(xt1)
178
179 a=T.dot(np.ones(len(Tx)))
180 b=(T.T).dot(np.ones(len(Sx)))
181 G=ot.da.sinkhorn_lpl1_mm(a,Sy,b,M,self.reg_e,self.reg_cl)
182 if np.sum(G)<0.5:
183 a=np.ones(len(Sx))/len(Sx)
184 G=np.outer(a,b)
185 transp_Xs_lpl1 = np.diag(1 / G.dot(np.ones(len(Tx)))) @ G.dot(Tx)
186 knn_clf = KNeighborsClassifier(n_neighbors=1)
187 knn_clf.fit(transp_Xs_lpl1, Sy)
188 Cls2 = knn_clf.predict(Tx)
189 return Cls2[index]
190
191 def fit_predict(self, Sx, Sy, Tx, Ty,sfilepath,sourcename,tfilepath,tmodelpath,targetname):
192 gmm_source(Sx,Sy,sfilepath,sourcename)
193 self.get_source(sfilepath,sourcename)
194 gmm_target(Tx,1,self.d,self.root_dir,tfilepath,tmodelpath,targetname)
195 self.get_target(tfilepath,tmodelpath,targetname)
196 ss1 = self.diag_s
197 tt1 = self.diag_t
198 xns,yns=ss1['xn1'],ss1['yns']
199 xntmu,xnt=tt1['xntmu'],tt1['xnt1']
200 pred= self.partot_DA(xns,yns,xnt,xntmu,Tx,ttt=2)

Callers 1

main.py · File · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected