Repository: github.com/MorvanZhou/Tensorflow-Tutorial

Function train

tutorial-contents/407_transfer_learning.py, lines 162–182
Signature: train() (no arguments)


```python
import numpy as np
import matplotlib.pyplot as plt

# load_data() and Vgg16 are defined earlier in 407_transfer_learning.py.


def train():
    # Images and (fake) lengths for the two classes
    tigers_x, cats_x, tigers_y, cats_y = load_data()

    # plot fake length distribution
    plt.hist(tigers_y, bins=20, label='Tigers')
    plt.hist(cats_y, bins=10, label='Cats')
    plt.legend()
    plt.xlabel('length')
    plt.show()

    # Combine both classes into one image array and one length array
    xs = np.concatenate(tigers_x + cats_x, axis=0)
    ys = np.concatenate((tigers_y, cats_y), axis=0)

    # Build the model, loading VGG16 weights from vgg16.npy
    vgg = Vgg16(vgg16_npy_path='./for_transfer_learning/vgg16.npy')
    print('Net built')
    for i in range(100):
        # Random mini-batch of 6 samples (indices drawn with replacement)
        b_idx = np.random.randint(0, len(xs), 6)
        train_loss = vgg.train(xs[b_idx], ys[b_idx])
        print(i, 'train loss: ', train_loss)

    vgg.save('./for_transfer_learning/model/transfer_learn')  # save learned fc layers
```
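The data-assembly and batching steps in `train()` can be tried in isolation. Below is a minimal NumPy sketch, assuming (as the `tigers_x + cats_x` expression suggests, but the listing does not show) that `load_data()` returns Python lists of single-image arrays shaped `(1, H, W, 3)` and `(N, 1)` arrays of lengths. `make_fake_data`, `stack_dataset`, and `sample_batch` are hypothetical helpers, the tiny 8x8 images and made-up lengths are not the tutorial's data, and `Generator.integers` is used as the newer counterpart of `np.random.randint`.

```python
import numpy as np


def make_fake_data(n_tigers=4, n_cats=3, size=8, seed=0):
    """Hypothetical stand-in for load_data(): lists of (1, size, size, 3)
    image arrays plus (N, 1) arrays of made-up lengths."""
    rng = np.random.default_rng(seed)
    tigers_x = [rng.random((1, size, size, 3)) for _ in range(n_tigers)]
    cats_x = [rng.random((1, size, size, 3)) for _ in range(n_cats)]
    tigers_y = rng.normal(100.0, 30.0, size=(n_tigers, 1))
    cats_y = rng.normal(40.0, 8.0, size=(n_cats, 1))
    return tigers_x, cats_x, tigers_y, cats_y


def stack_dataset(tigers_x, cats_x, tigers_y, cats_y):
    # `tigers_x + cats_x` joins two Python lists; np.concatenate then stacks
    # the (1, H, W, 3) arrays along axis 0 into a single (N, H, W, 3) array.
    xs = np.concatenate(tigers_x + cats_x, axis=0)
    # The length arrays are already NumPy arrays, so they are passed as a tuple
    # (using `+` on two arrays would add them element-wise instead).
    ys = np.concatenate((tigers_y, cats_y), axis=0)
    return xs, ys


def sample_batch(xs, ys, batch_size, rng):
    # Like np.random.randint in train(), Generator.integers draws indices
    # with replacement, so a batch can contain the same sample twice.
    b_idx = rng.integers(0, len(xs), batch_size)
    return xs[b_idx], ys[b_idx]


if __name__ == "__main__":
    tx, cx, ty, cy = make_fake_data()
    xs, ys = stack_dataset(tx, cx, ty, cy)
    print(xs.shape, ys.shape)  # (7, 8, 8, 3) (7, 1)
    bx, by = sample_batch(xs, ys, 6, np.random.default_rng(1))
    print(bx.shape, by.shape)  # (6, 8, 8, 3) (6, 1)
```

Tigers come first in the stacked arrays, so row `i` of `xs` always lines up with row `i` of `ys`. If a batch should never repeat a sample, `rng.choice(len(xs), batch_size, replace=False)` draws indices without replacement; the listing does not say which behaviour is intended.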

Callers

nothing calls this directly

Calls (4)

train (method) · 0.95
save (method) · 0.95
load_data (function) · 0.85
Vgg16 (class) · 0.85
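The calls listed above show `train()` delegating all learning to a `Vgg16` object's `train` and `save` methods, whose bodies this page does not show. The comment "save learned fc layers" points to the usual transfer-learning pattern: keep the pretrained layers frozen and fit only a small head on top. As a hedged illustration of that pattern (not the tutorial's `Vgg16`), here is a NumPy sketch in which a fixed random projection stands in for the pretrained convolutional layers and a linear regression head is trained by mini-batch SGD. The class names `FrozenFeatures` and `LinearHead`, the mean-squared-error loss, the learning rate, and the synthetic data are all assumptions.

```python
import os
import tempfile

import numpy as np


class FrozenFeatures:
    """Hypothetical stand-in for pretrained conv layers: a fixed random
    projection plus ReLU. Its weights are never updated ("frozen")."""

    def __init__(self, in_dim, feat_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.w = rng.normal(0.0, 1.0 / np.sqrt(in_dim), size=(in_dim, feat_dim))

    def __call__(self, x):
        # Flatten each image, project it, then apply ReLU.
        return np.maximum(0.0, x.reshape(len(x), -1) @ self.w)


class LinearHead:
    """Hypothetical trainable regression head; the only part that learns."""

    def __init__(self, feat_dim):
        self.w = np.zeros((feat_dim, 1))
        self.b = np.zeros(1)

    def predict(self, feats):
        return feats @ self.w + self.b

    def train_step(self, feats, y, lr):
        err = self.predict(feats) - y           # shape (batch, 1)
        loss = float(np.mean(err ** 2))         # mean squared error
        # Gradients of the MSE with respect to w and b.
        grad_w = 2.0 * feats.T @ err / len(y)
        grad_b = 2.0 * err.mean(axis=0)
        self.w -= lr * grad_w
        self.b -= lr * grad_b
        return loss

    def save(self, path):
        # Persist only the learned head parameters, not the frozen features.
        np.savez(path, w=self.w, b=self.b)


def fit(xs, ys, steps=100, batch_size=6, lr=0.05, feat_dim=8, seed=0):
    backbone = FrozenFeatures(xs[0].size, feat_dim, seed=seed)
    head = LinearHead(feat_dim)
    rng = np.random.default_rng(seed + 1)
    losses = []
    for _ in range(steps):
        # Random mini-batch of indices, drawn with replacement.
        b_idx = rng.integers(0, len(xs), batch_size)
        losses.append(head.train_step(backbone(xs[b_idx]), ys[b_idx], lr))
    return backbone, head, losses


if __name__ == "__main__":
    rng = np.random.default_rng(42)
    # Synthetic data: 40 "tiger" and 30 "cat" random 8x8 RGB arrays with
    # made-up lengths (not the tutorial's images).
    xs = rng.random((70, 8, 8, 3))
    ys = np.concatenate([rng.normal(100, 30, (40, 1)), rng.normal(40, 8, (30, 1))])
    backbone, head, losses = fit(xs, ys)
    # The zero-initialised head predicts 0, so the "before" MSE is mean(y**2).
    print("MSE before:", float(np.mean(ys ** 2)))
    print("MSE after: ", float(np.mean((head.predict(backbone(xs)) - ys) ** 2)))
    head.save(os.path.join(tempfile.gettempdir(), "transfer_head.npz"))
```

Because only `LinearHead` is updated and saved, the saved file stays small and the frozen part can be rebuilt from its original weights, which is the property the "save learned fc layers" comment suggests. Since these random images carry no information about length, the head can at best learn the mean length here; with real pretrained features it could do better.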

Tested by

no test coverage detected