| 91 | |
| 92 | # ############################## Network #################################### |
class MLP(tl.models.Model):
    """Fully connected network producing per-action log-probabilities over atoms.

    Architecture: one 64-unit tanh ``Dense`` layer, then a linear ``Dense``
    layer with ``out_dim * atom_num`` units, reshaped to
    ``(-1, out_dim, atom_num)`` and normalised with ``log_softmax`` over
    axis 2 (the atom axis). Presumably the value-distribution head of a
    categorical (C51-style) DQN -- confirm against the training code.

    The sizes ``in_dim``, ``out_dim`` and ``atom_num`` are not constructor
    parameters; they are looked up as module-level names defined elsewhere
    in this file.
    """

    def __init__(self, name):
        """Build the hidden, output and reshape layers.

        Args:
            name: model name forwarded to ``tl.models.Model``.
        """
        super().__init__(name=name)
        # Hidden layer: input width is taken from the first entry of ``in_dim``.
        self.h1 = tl.layers.Dense(
            64, tf.nn.tanh, in_channels=in_dim[0], W_init=tf.initializers.GlorotUniform()
        )
        # Output layer emits one logit per (action, atom) pair, flattened.
        n_logits = out_dim * atom_num
        self.qvalue = tl.layers.Dense(
            n_logits, in_channels=64, name='q', W_init=tf.initializers.GlorotUniform()
        )
        # Un-flatten the logits so the atoms of each action sit on their own axis.
        self.reshape = tl.layers.Reshape((-1, out_dim, atom_num))

    def forward(self, ni):
        """Return log-probabilities reshaped to ``(-1, out_dim, atom_num)``.

        Args:
            ni: input batch; assumed to be ``(batch, in_dim[0])`` -- TODO confirm.
        """
        hidden = self.h1(ni)
        flat_logits = self.qvalue(hidden)
        logits = self.reshape(flat_logits)
        # Normalise across atoms (axis 2): exp() of the result sums to 1 per action.
        return tf.nn.log_softmax(logits, axis=2)
| 107 | |
| 108 | class CNN(tl.models.Model): |
no outgoing calls