Repository: github.com/TheAlgorithms/Python

Method fit(self)

machine_learning/sequential_minimum_optimization.py:78–140


# Calculate alphas using the SMO algorithm
def fit(self):
    k = self._k
    # Feedback sent to the pair selector: None after a successful update,
    # False when the chosen pair produced no progress
    state = None
    while True:
        # 1: Find alpha1, alpha2
        try:
            i1, i2 = self.choose_alpha.send(state)
            state = None
        except StopIteration:
            print("Optimization done!\nEvery sample satisfy the KKT condition!")
            break

        # 2: calculate new alpha2 and new alpha1
        y1, y2 = self.tags[i1], self.tags[i2]
        a1, a2 = self.alphas[i1].copy(), self.alphas[i2].copy()
        e1, e2 = self._e(i1), self._e(i2)
        args = (i1, i2, a1, a2, e1, e2, y1, y2)
        a1_new, a2_new = self._get_new_alpha(*args)
        if not a1_new and not a2_new:
            # No usable update for this pair: tell the selector and try again
            state = False
            continue
        self.alphas[i1], self.alphas[i2] = a1_new, a2_new

        # 3: update threshold (b)
        b1_new = np.float64(
            -e1
            - y1 * k(i1, i1) * (a1_new - a1)
            - y2 * k(i2, i1) * (a2_new - a2)
            + self._b
        )
        b2_new = np.float64(
            -e2
            - y2 * k(i2, i2) * (a2_new - a2)
            - y1 * k(i1, i2) * (a1_new - a1)
            + self._b
        )
        # Use the b computed from a new alpha that lies strictly inside (0, C);
        # b2 wins if both do, and the midpoint is used if neither does
        if 0.0 < a1_new < self._c:
            b = b1_new
        if 0.0 < a2_new < self._c:
            b = b2_new
        if not (np.float64(0) < a2_new < self._c) and not (
            np.float64(0) < a1_new < self._c
        ):
            b = (b1_new + b2_new) / 2.0
        b_old = self._b
        self._b = b

        # 4: update the cached errors; only non-bound samples are refreshed here
        self._unbound = [i for i in self._all_samples if self._is_unbound(i)]
        for s in self.unbound:
            if s in (i1, i2):
                continue
            self._error[s] += (
                y1 * (a1_new - a1) * k(i1, s)
                + y2 * (a2_new - a2) * k(i2, s)
                + (self._b - b_old)
            )
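Step 1 treats `choose_alpha` as a generator driven with `send`: `fit` sends `None` after it has updated a pair and `False` when `_get_new_alpha` returned no usable values, and it stops when the generator raises `StopIteration`. The sketch below reproduces only that control flow. `candidate_pairs` and `drive` are hypothetical names, and the selection order (each i1 paired with every other i2, moving to the next i1 once a pair is accepted) is an assumption for illustration, not the repository's pair-selection heuristic.

```python
from typing import Callable, Generator, List, Optional, Tuple

Pair = Tuple[int, int]


def candidate_pairs(n: int) -> Generator[Pair, Optional[bool], None]:
    """Hypothetical selector: for each i1, offer partners i2 until one is accepted.

    Receiving None means the caller updated the pair (move on to the next i1);
    receiving False means no progress (offer the next i2 for the same i1).
    """
    for i1 in range(n):
        for i2 in range(n):
            if i2 == i1:
                continue
            cmd = yield i1, i2
            if cmd is None:
                break  # pair accepted: continue with the next i1


def drive(n: int, accept: Callable[[int, int], bool]) -> List[Pair]:
    """Drive the generator the way fit() does and record the accepted pairs."""
    gen = candidate_pairs(n)
    accepted: List[Pair] = []
    state: Optional[bool] = None
    while True:
        try:
            i1, i2 = gen.send(state)  # the first send(None) starts the generator
            state = None
        except StopIteration:
            break  # selector exhausted: stop, as fit() does
        if not accept(i1, i2):
            state = False  # report "no progress" for this pair
            continue
        accepted.append((i1, i2))
    return accepted
```

Sending `False` keeps this generator on the same i1 and offers the next partner, while sending `None` moves it to the next i1; a plain `for` loop cannot pass that feedback back into the generator.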

Callers 3

test_cancer_data · Function · 0.95
test_linear_kernel · Function · 0.95
test_rbf_kernel · Function · 0.95

Calls 4

_e · Method · 0.95
_get_new_alpha · Method · 0.95
_is_unbound · Method · 0.95
copy · Method · 0.80
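Step 2 delegates the update itself to `_get_new_alpha`, whose body is not shown on this page. Assuming it follows the textbook two-variable update from Platt's SMO, a minimal standalone sketch looks like this. The name `two_variable_update` and its argument list are hypothetical, and the `eta <= 0` case is simply skipped here instead of being handled by evaluating the objective at the box ends, as the full algorithm does. The sketch uses the sign convention that steps 3 and 4 are consistent with, f(x) = Σ α_j y_j K(x_j, x) + b and E_i = f(x_i) - y_i, and returns `(None, None)` when no progress is possible, which the falsy check in step 2 would treat as "no progress".

```python
from typing import Optional, Tuple


def two_variable_update(
    a1: float, a2: float, y1: float, y2: float,
    e1: float, e2: float,
    k11: float, k12: float, k22: float, c: float,
) -> Tuple[Optional[float], Optional[float]]:
    """Textbook SMO step for one pair; returns (None, None) when no progress is possible."""
    # Box [lo, hi] for the new a2 implied by 0 <= a <= C and y1*a1 + y2*a2 = const
    if y1 != y2:
        lo, hi = max(0.0, a2 - a1), min(c, c + a2 - a1)
    else:
        lo, hi = max(0.0, a1 + a2 - c), min(c, a1 + a2)
    if lo == hi:
        return None, None
    # Curvature of the dual objective along the constraint line
    eta = k11 + k22 - 2.0 * k12
    if eta <= 0.0:
        # Simplification: the full algorithm compares the objective at lo and hi here
        return None, None
    a2_new = a2 + y2 * (e1 - e2) / eta
    a2_new = min(max(a2_new, lo), hi)  # clip to the box
    a1_new = a1 + y1 * y2 * (a2 - a2_new)  # keep y1*a1 + y2*a2 unchanged
    return a1_new, a2_new
```

On the two-point problem x = 1 (label +1) and x = -1 (label -1) with a linear kernel, C = 10 and both alphas at zero (so E1 = -1 and E2 = 1), `two_variable_update(0.0, 0.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 10.0)` returns `(0.5, 0.5)`.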

Tested by 3

test_cancer_data · Function · 0.76
test_linear_kernel · Function · 0.76
test_rbf_kernel · Function · 0.76
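Besides the end-to-end tests listed above, steps 3 and 4 can be checked in isolation: after the threshold update, adding the step-4 increment to the old errors should give exactly the errors recomputed from scratch. The sketch below assumes a precomputed symmetric kernel matrix `K` and the convention f(x_i) = Σ_j α_j y_j K[j, i] + b, E_i = f(x_i) - y_i; the function name `threshold_and_errors` is hypothetical. Unlike `fit`, which only refreshes cached errors for non-bound samples other than i1 and i2, it updates every sample so the result can be compared with a full recomputation.

```python
from typing import Tuple

import numpy as np


def threshold_and_errors(
    K: np.ndarray, y: np.ndarray, alphas: np.ndarray, b: float,
    i1: int, i2: int, a1_new: float, a2_new: float, c: float,
) -> Tuple[float, np.ndarray]:
    """Mirror steps 3 and 4 of fit() on a precomputed, symmetric kernel matrix K."""
    errors = K @ (alphas * y) + b - y  # E_i before the update
    a1, a2, y1, y2 = alphas[i1], alphas[i2], y[i1], y[i2]
    e1, e2 = errors[i1], errors[i2]
    # Same b1/b2 formulas as step 3
    b1 = -e1 - y1 * K[i1, i1] * (a1_new - a1) - y2 * K[i2, i1] * (a2_new - a2) + b
    b2 = -e2 - y2 * K[i2, i2] * (a2_new - a2) - y1 * K[i1, i2] * (a1_new - a1) + b
    # Same precedence as fit(): b2 if alpha2 is non-bound, else b1 if alpha1 is,
    # else the midpoint
    if 0.0 < a2_new < c:
        b_new = b2
    elif 0.0 < a1_new < c:
        b_new = b1
    else:
        b_new = (b1 + b2) / 2.0
    # Step-4 increment, applied to every sample
    delta = (
        y1 * (a1_new - a1) * K[i1, :]
        + y2 * (a2_new - a2) * K[i2, :]
        + (b_new - b)
    )
    return float(b_new), errors + delta


# Toy problem: x = [1, -1] with labels [1, -1], linear kernel, C = 10
X = np.array([[1.0], [-1.0]])
y = np.array([1.0, -1.0])
K = X @ X.T
b_new, errors = threshold_and_errors(K, y, np.zeros(2), 0.0, 0, 1, 0.5, 0.5, 10.0)
# b_new is 0.0 and errors is all zeros: both points end up on their margins
```

Because the step-4 increment is just the change in Σ α_j y_j K[j, i] + b, the identity holds for any alphas; the zero errors in the toy run additionally reflect that both new alphas are non-bound.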