MCPcopy Index your code
hub / github.com/lazyprogrammer/machine_learning_examples / fit

Method fit

hmm_class/hmmd_theano.py:26–88  ·  view source on GitHub ↗
(self, X, learning_rate=0.001, max_iter=10, V=None, p_cost=1.0, print_period=10)

Source from the content-addressed store, hash-verified

24 self.M = M # number of hidden states
25
def fit(self, X, learning_rate=0.001, max_iter=10, V=None, p_cost=1.0, print_period=10):
    """Train the discrete-observation HMM with per-sequence gradient descent.

    Parameters
    ----------
    X : jagged array of observed sequences; the observables are assumed to
        already be integers in 0..V-1.
    learning_rate : step size applied to each gradient.
    max_iter : number of full passes over X.
    V : vocabulary size. When None it is taken as the largest symbol seen
        in X plus one.
    p_cost : forwarded unchanged to ``self.get_cost_multi``.
    print_period : an "it:" progress line is printed whenever the pass
        index is a multiple of this value.

    Side effects: ``self.pi``, ``self.A`` and ``self.B`` are overwritten by
    every training step, and progress plus the final ``A`` are printed.
    """
    # Infer the vocabulary size from the data when the caller did not give it.
    if V is None:
        largest_symbol = max(max(sequence) for sequence in X)
        V = largest_symbol + 1

    N = len(X)
    print("number of train samples:", N)

    # Starting parameters: uniform initial-state distribution, random
    # transition and emission matrices from random_normalized.
    n_states = self.M
    pi0 = np.ones(n_states) / n_states
    A0 = random_normalized(n_states, n_states)
    B0 = random_normalized(n_states, V)

    # NOTE(review): self.set presumably installs self.pi / self.A / self.B
    # and returns the symbolic input sequence together with its cost --
    # confirm against its definition.
    thx, cost = self.set(pi0, A0, B0)

    def gradient_step(param):
        # One plain SGD step on `param` with respect to the sequence cost.
        return param - learning_rate*T.grad(cost, param)

    # After each step the parameters are rescaled to sum to one again:
    # pi as a whole vector, A and B row by row.
    pi_stepped = gradient_step(self.pi)
    pi_update = pi_stepped / pi_stepped.sum()

    A_stepped = gradient_step(self.A)
    A_update = A_stepped / A_stepped.sum(axis=1).dimshuffle(0, 'x')

    B_stepped = gradient_step(self.B)
    B_update = B_stepped / B_stepped.sum(axis=1).dimshuffle(0, 'x')

    train_op = theano.function(
        inputs=[thx],
        updates=[
            (self.pi, pi_update),
            (self.A, A_update),
            (self.B, B_update),
        ],
        allow_input_downcast=True,
    )

    costs = []
    for epoch in range(max_iter):
        if epoch % print_period == 0:
            print("it:", epoch)

        for sample_index in range(N):
            # The cost over the whole dataset is recorded before every
            # single-sequence update; this is deliberately thorough and
            # correspondingly slow.
            total_cost = self.get_cost_multi(X, p_cost).sum()
            costs.append(total_cost)
            train_op(X[sample_index])

    print("A:", self.A.get_value())
    # NOTE(review): the listing stops here although the stated source span
    # runs to line 88, so the remaining statements (and any use of `costs`)
    # are not visible in this view.

Callers 1

fit_coin — Function · 0.95

Calls 4

set — Method · 0.95
get_cost_multi — Method · 0.95
random_normalized — Function · 0.70
grad — Method · 0.45

Tested by

no test coverage detected