MCPcopy
hub / github.com/yenchenlin/DeepLearningFlappyBird / createNetwork

Function createNetwork

deep_q_network.py:38–76  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

36 return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = "SAME")
37
def createNetwork():
    """Build the Q-network graph and return its input and output tensors.

    Architecture (as built below):
      * conv1: 8x8x4 -> 32 filters, ``conv2d`` step 4, ReLU, then 2x2 max-pool
      * conv2: 4x4x32 -> 64 filters, ``conv2d`` step 2, ReLU
      * conv3: 3x3x64 -> 64 filters, ``conv2d`` step 1, ReLU
      * fc1:   flattened conv3 -> 512 units, ReLU
      * readout: 512 -> ``ACTIONS`` units, linear (no activation)

    Returns:
        s: float placeholder of shape ``[None, 80, 80, 4]``; the network input.
            Presumably a batch of stacked game frames — confirm with the caller.
        readout: tensor of shape ``[None, ACTIONS]``; the linear output of the
            final layer (one value per action, presumably the Q-values).
        h_fc1: the 512-unit ReLU activation of the fully connected hidden layer.

    NOTE(review): ``weight_variable`` / ``bias_variable`` are defined elsewhere.
    If they create unnamed ``tf.Variable``s, TensorFlow names them by creation
    order ("Variable", "Variable_1", ...), and ``tf.train.Saver`` matches
    checkpoints by those names. Keep the variable-creation order below
    unchanged so previously saved networks still restore.
    """
    # Flattened size of h_conv3 that feeds fc1 (= 1600). For an 80x80 input:
    # 80 -> 20 (conv1, step 4) -> 10 (2x2 pool) -> 5 (conv2, step 2)
    # -> 5 (conv3, step 1), so 5 * 5 * 64 channels. This assumes conv2d's third
    # argument is the stride and that it uses "SAME" padding like max_pool_2x2
    # does — TODO confirm against conv2d.
    conv3_flat_dim = 5 * 5 * 64

    # Network weights. Creation order is intentionally preserved (see docstring).
    W_conv1 = weight_variable([8, 8, 4, 32])
    b_conv1 = bias_variable([32])

    W_conv2 = weight_variable([4, 4, 32, 64])
    b_conv2 = bias_variable([64])

    W_conv3 = weight_variable([3, 3, 64, 64])
    b_conv3 = bias_variable([64])

    W_fc1 = weight_variable([conv3_flat_dim, 512])
    b_fc1 = bias_variable([512])

    W_fc2 = weight_variable([512, ACTIONS])
    b_fc2 = bias_variable([ACTIONS])

    # Input layer.
    s = tf.placeholder("float", [None, 80, 80, 4])

    # Hidden layers. Only conv1 is followed by pooling; conv2 and conv3 reduce
    # spatial size through their strides alone.
    h_conv1 = tf.nn.relu(conv2d(s, W_conv1, 4) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2, 2) + b_conv2)

    h_conv3 = tf.nn.relu(conv2d(h_conv2, W_conv3, 1) + b_conv3)

    h_conv3_flat = tf.reshape(h_conv3, [-1, conv3_flat_dim])

    h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat, W_fc1) + b_fc1)

    # Readout layer: linear, one output per action.
    readout = tf.matmul(h_fc1, W_fc2) + b_fc2

    return s, readout, h_fc1
77
78def trainNetwork(s, readout, h_fc1, sess):
79 # define the cost function

Callers 1

playGame (function) · 0.85

Calls 4

weight_variable (function) · 0.85
bias_variable (function) · 0.85
conv2d (function) · 0.85
max_pool_2x2 (function) · 0.85

Tested by

no test coverage detected