| 160 | |
| 161 | |
def train(*, n_steps=100, batch_size=6,
          vgg16_npy_path='./for_transfer_learning/vgg16.npy',
          save_path='./for_transfer_learning/model/transfer_learn',
          show_plot=True):
    """Train ``Vgg16`` on the tiger/cat data and save the learned layers.

    Loads images and their (fake) length targets via ``load_data``,
    optionally plots a histogram of the targets, then runs ``n_steps``
    mini-batch training iterations on randomly sampled examples, printing
    the loss returned by ``vgg.train`` each step, and finally saves the model.

    Calling ``train()`` with no arguments reproduces the previous behaviour
    (100 steps, batch size 6, same weight/save paths, plot shown).

    Args:
        n_steps: Number of training iterations.
        batch_size: Number of examples per mini-batch. Indices are drawn
            with replacement, so a batch may repeat an example.
        vgg16_npy_path: Pretrained weights file passed to ``Vgg16``.
        save_path: Path passed to ``vgg.save`` after training.
        show_plot: If True, show the target-length histogram before
            training. ``plt.show()`` may block until the window is closed.

    Raises:
        ValueError: If ``n_steps`` is negative, ``batch_size`` is not
            positive, or ``load_data`` produced no examples.
    """
    # Validate cheap arguments before doing any (slow) data loading.
    if n_steps < 0:
        raise ValueError(f'n_steps must be >= 0, got {n_steps}')
    if batch_size <= 0:
        raise ValueError(f'batch_size must be > 0, got {batch_size}')

    tigers_x, cats_x, tigers_y, cats_y = load_data()

    if show_plot:
        # plot fake length distribution; the two histograms share one axis
        # and overlap, so draw them translucent to keep both visible
        plt.hist(tigers_y, bins=20, alpha=0.5, label='Tigers')
        plt.hist(cats_y, bins=10, alpha=0.5, label='Cats')
        plt.legend()
        plt.xlabel('length')
        plt.show()

    # NOTE(review): ``+`` assumes tigers_x / cats_x are Python lists of image
    # arrays (list concatenation). If load_data ever returned ndarrays, ``+``
    # would add them element-wise instead -- confirm against load_data.
    xs = np.concatenate(tigers_x + cats_x, axis=0)
    ys = np.concatenate((tigers_y, cats_y), axis=0)
    if len(xs) == 0:
        raise ValueError('load_data() returned no training examples')

    vgg = Vgg16(vgg16_npy_path=vgg16_npy_path)
    print('Net built')
    for i in range(n_steps):
        # random mini-batch of indices, sampled with replacement
        b_idx = np.random.randint(0, len(xs), batch_size)
        train_loss = vgg.train(xs[b_idx], ys[b_idx])
        print(i, 'train loss: ', train_loss)

    vgg.save(save_path)  # save learned fc layers
| 183 | |
| 184 | |
| 185 | def eval(): |