MCPcopy
hub / github.com/jindongwang/transferlearning / gmm_source

Function gmm_source

code/traditional/sot/SOT.py:94–118  ·  view source on GitHub ↗
(xs,ys,filepath,sourcename,slist=[],covtype='diag')

Source from the content-addressed store, hash-verified

92 return np.array(xmu),np.array(x1),np.array(x2),gmm
93
def gmm_source(xs, ys, filepath, sourcename, slist=None, covtype='diag'):
    """Fit per-class GMMs on the source domain and cache the result as JSON.

    If ``filepath`` already exists, nothing is computed and the function
    returns immediately, so the file effectively acts as a cache.

    Otherwise, for every distinct label in ``ys`` (visited in sorted order),
    the rows of ``xs`` carrying that label are passed to ``gmm_source_class``
    together with the label's relative frequency ``count / len(ys)``. The
    per-class outputs are concatenated and written to ``filepath`` as
    ``{sourcename: {'xmu': ..., 'xn1': ..., 'xn2': ..., 'yns': ...}}``.

    Args:
        xs: Source samples, indexed by row (one row per entry of ``ys``).
        ys: 1-D sequence of class labels, one per sample. Labels need not be
            contiguous or start at 0; any sortable, hashable numeric labels
            work.
        filepath: Path of the JSON file to write; its existence short-circuits
            the whole computation.
        sourcename: Top-level key under which the data is stored in the JSON.
        slist: Forwarded to ``gmm_source_class``. When ``None`` or empty it
            defaults to ``np.arange(1, len(ys) + 1)`` (presumably the
            candidate component counts -- confirm against gmm_source_class).
        covtype: Forwarded to ``gmm_source_class`` as ``covtype``.

    Returns:
        None. The result is only persisted to ``filepath``.

    Raises:
        ValueError: If ``ys`` is empty (there is nothing to fit).
    """
    if os.path.exists(filepath):
        # Result already on disk: reuse it instead of refitting.
        return

    ys = np.asarray(ys)  # makes `ys == label` an elementwise comparison
    n_samples = len(ys)
    if n_samples == 0:
        raise ValueError("gmm_source: ys is empty, nothing to fit")

    counts = Counter(ys)
    if slist is None or len(slist) == 0:
        slist = np.arange(1, n_samples + 1)

    xmu = xn1 = xn2 = yns = None
    # Iterate the labels actually present rather than range(n_classes):
    # the latter skipped/invented classes when labels were not 0..C-1.
    for label in sorted(counts):
        weight = counts[label] / n_samples
        xtmu, xt1, xt2, _ = gmm_source_class(
            xs[np.where(ys == label)[0]], weight, slist, covtype=covtype)
        # One label entry per row returned for this class (float, as before).
        yts = np.ones(len(xt1)) * label
        if xmu is None:
            xmu, xn1, xn2, yns = xtmu, xt1, xt2, yts
        else:
            xmu = np.hstack((xmu, xtmu))
            xn1 = np.vstack((xn1, xt1))
            xn2 = np.vstack((xn2, xt2))
            yns = np.hstack((yns, yts))

    data = {'xmu': xmu, 'xn1': xn1, 'xn2': xn2, 'yns': yns}
    record = {sourcename: data}
    # MyEncoder is defined elsewhere in the file; presumably it serializes
    # numpy arrays -- confirm.
    with open(filepath, 'w') as f:
        json.dump(record, f, cls=MyEncoder)
119
120def entropic_partial_wasserstein(a, b, M, reg, m=1, numItermax=500,
121 stopThr=1e-100, verbose=False, log=False):

Callers 1

fit_predict (method) · 0.85

Calls 1

gmm_source_class (function) · 0.85

Tested by

no test coverage detected