MCP · copy
hub / github.com/MrNothing/AI-Blocks / Run

Function Run

Sources/scripts/rnn.py:25–78  ·  view source on GitHub ↗
(self, graph, reuse=False)

Source from the content-addressed store, hash-verified

23self.variables = []
24
def Run(self, graph, reuse=False):
    """Build a stacked GRU/LSTM over ``graph`` followed by a dense output layer.

    The input is first reshaped into a 3-D sequence tensor:

    * if ``self.reshape`` has at most one element, the shape becomes
      ``[-1, graph.shape[1], 1]``, i.e. one feature per time step;
    * otherwise ``self.reshape`` is used verbatim.

    One cell is created per entry of ``self.hidden_units``. It is a ``GRUCell``
    when ``self.type == "GRU"`` and an ``LSTMCell`` (tuple state) otherwise.
    Each cell is wrapped in a ``DropoutWrapper`` with
    ``output_keep_prob=self.dropout``.

    The stack is unrolled with ``tf.nn.dynamic_rnn`` inside
    ``tf.variable_scope(self.name, reuse=reuse)``. The output at the last time
    step is then projected to ``self.n_classes`` values with ``self.weightV`` and
    ``self.biasV``.

    Args:
        graph: Input tensor. NOTE(review): when ``self.reshape`` is empty or
            has one element, this appears to expect a ``(batch, time)``
            input. A higher-rank input would have its trailing dimensions
            folded into the batch by the reshape. Confirm against callers.
        reuse: Forwarded to ``tf.variable_scope`` for the RNN weights. When
            truthy and this block already built its output layer in the
            current default graph, that ``weightV``/``biasV`` pair is reused
            instead of creating a new one.

    Returns:
        ``tf.matmul(last_step_output, self.weightV) + self.biasV``, whose last
        dimension is ``self.n_classes``.

    Side effects:
        Sets ``self.state`` to the final RNN state, for both cell types.

        When the output layer is newly created, sets ``self.weightV`` and
        ``self.biasV`` and appends both to ``self.variables``. Only these two
        are tracked there; the RNN cell weights created by ``dynamic_rnn``
        are not added. NOTE(review): confirm callers of ``getVariables`` do
        not expect the full set.
    """
    Log(graph.get_shape())

    if len(self.reshape) <= 1:
        # Treat each input column as one time step carrying a single feature.
        graph = tf.reshape(graph, [-1, int(graph.get_shape()[1]), 1])
    else:
        graph = tf.reshape(graph, self.reshape)

    # The GRU and LSTM variants differ only in how each layer's cell is built.
    if self.type == "GRU":
        make_cell = tf.nn.rnn_cell.GRUCell
    else:
        def make_cell(n_hidden):
            return tf.nn.rnn_cell.LSTMCell(n_hidden, state_is_tuple=True)

    cells = [
        tf.nn.rnn_cell.DropoutWrapper(make_cell(n_hidden), output_keep_prob=self.dropout)
        for n_hidden in self.hidden_units
    ]
    cell = tf.nn.rnn_cell.MultiRNNCell(cells)

    with tf.variable_scope(self.name, reuse=reuse):
        val, self.state = tf.nn.dynamic_rnn(cell, graph, dtype=tf.float32)

    # dynamic_rnn output is batch-major by default, so the last step is
    # val[:, -1, :]. Unlike transpose + gather(int(shape[0]) - 1), this also
    # works when the time dimension is not known statically.
    last = val[:, -1, :]

    # tf.Variable ignores variable_scope reuse. Share the output layer
    # explicitly; otherwise every reuse call would get a fresh, unshared
    # projection and push duplicate entries into self.variables. The graph
    # check avoids reusing variables left over from a previous graph.
    existing = getattr(self, "weightV", None)
    share_output = (
        reuse
        and existing is not None
        and existing.graph is tf.get_default_graph()
    )
    if not share_output:
        self.weightV = tf.Variable(
            tf.truncated_normal([self.hidden_units[-1], self.n_classes]),
            name=self.name + "_W",
        )
        self.biasV = tf.Variable(
            tf.constant(0.1, shape=[self.n_classes]),
            name=self.name + "_b",
        )
        self.variables.append(self.weightV)
        self.variables.append(self.biasV)

    obj = tf.matmul(last, self.weightV) + self.biasV

    Log(self.name+": "+str(self.hidden_units)+" => "+str(obj.get_shape()))

    return obj
79
def getVariables(self):
    """Return the list of variables registered on this block.

    The list object itself is handed back, not a copy. Callers therefore see
    later additions, and any mutation they make is visible to the block.
    """
    tracked = self.variables
    return tracked

Callers

nothing calls this directly

Calls 1

Log · Function · 0.85

Tested by

no test coverage detected