MCPcopy
hub / github.com/jindongwang/transferlearning / get_class_center

Function get_class_center

code/traditional/pyEasyTL/EasyTL.py:42–72  ·  view source on GitHub ↗
(Xs,Ys,Xt,dist)

Source from the content-addressed store, hash-verified

40 return Dct_c
41
def get_class_center(Xs,Ys,Xt,dist):
    """Compute per-class source means and each target sample's distance to every class.

    Args:
        Xs: Source samples as a 2-D array, one sample per row.
        Ys: Source labels aligned with the rows of ``Xs``. Any shape that
            flattens to one label per row is accepted (e.g. ``(n_s,)`` or
            ``(n_s, 1)``).
        Xt: Target samples as a 2-D array, one sample per row, with the same
            number of columns as ``Xs``.
        dist: Distance to use. Must be one of:
            - ``"ma"``: delegated to ``get_ma_dist(Xt, X_i)``. It receives the
              class's samples, not only its mean.
            - ``"euclidean"``: Euclidean distance to the class mean.
            - ``"sqeuc"``: squared Euclidean distance to the class mean.
            - ``"cosine"``: delegated to ``get_cosine_dist(Xt, mean_i)``.
            - ``"rbf"``: ``exp(-squared_distance)``, i.e. a Gaussian kernel
              with a fixed bandwidth of 1. This is a similarity, not a
              distance.
            NaN entries are ignored (``np.nansum``) for euclidean/sqeuc/rbf.

    Returns:
        A ``(source_class_center, Dct)`` tuple:
            - ``source_class_center``: the class means as columns, ordered by
              ``np.unique(Ys)`` (sorted labels).
            - ``Dct``: one column per class in the same order. Each column is
              that class's distance values reshaped to a column. For the
              built-in metrics this gives one row per target sample. For
              "ma"/"cosine" this assumes the helpers return one value per row
              of ``Xt`` (TODO confirm).
        If ``Ys`` contains no labels, both values are empty 1-D arrays.

    Raises:
        ValueError: If ``dist`` is not one of the supported names.
    """
    supported = ("ma", "euclidean", "sqeuc", "cosine", "rbf")
    # Fail fast with a clear message; otherwise an unknown metric would surface
    # as an UnboundLocalError on Dct_c deep inside the loop.
    if dist not in supported:
        raise ValueError(
            "unsupported dist %r; expected one of %s" % (dist, ", ".join(supported))
        )

    # Accumulate columns in lists and stack once at the end. This avoids
    # re-copying the growing array on every class. It also avoids the old
    # "len(x) == 0 means first iteration" check, which misfired when Xt had
    # zero rows.
    centers = []
    dist_columns = []
    for label in np.unique(Ys):
        X_i = Xs[(Ys == label).flatten()]
        mean_i = np.mean(X_i, axis=0)
        centers.append(mean_i.reshape(-1, 1))

        if dist == "ma":
            Dct_c = get_ma_dist(Xt, X_i)
        elif dist == "cosine":
            Dct_c = get_cosine_dist(Xt, mean_i)
        else:
            # Shared by euclidean / sqeuc / rbf: squared distance per target row.
            sq_dist = np.nansum((mean_i - Xt) ** 2, axis=1)
            if dist == "euclidean":
                Dct_c = np.sqrt(sq_dist)
            elif dist == "sqeuc":
                Dct_c = sq_dist
            else:  # "rbf" with a fixed bandwidth of 1
                Dct_c = np.exp(-sq_dist)

        dist_columns.append(Dct_c.reshape(-1, 1))

    if not centers:
        # No labels at all: keep the original empty-array return values.
        return np.array([]), np.array([])

    source_class_center = np.hstack(centers)
    Dct = np.hstack(dist_columns)
    return source_class_center, Dct
73
74def EasyTL(Xs,Ys,Xt,Yt,intra_align="coral",dist="euclidean",lp="linear"):
75# Inputs:

Callers 1

EasyTL (Function) · 0.85

Calls 3

get_ma_dist (Function) · 0.85
get_cosine_dist (Function) · 0.85
mean (Method) · 0.45

Tested by

no test coverage detected