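Compile the model, train and evaluate it, then save and reload its weights and the whole model file. :param model: model under test :param model_name: name used in the progress messages (and, originally, as the prefix of the saved file names) :param x: training inputs passed to fit :param y: training labels passed to fit :param check_model_io: whether to also save and reload the whole model object :return: None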
(model, model_name, x, y, check_model_io=True)
| 148 | |
| 149 | |
def check_model(model, model_name, x, y, check_model_io=True):
    """Compile ``model``, train and evaluate it, then round-trip it through disk.

    The model is compiled with ``'adam'`` and ``'binary_crossentropy'``,
    fitted for one epoch on ``x``/``y`` with ``validation_split=0.5`` and
    EarlyStopping/ModelCheckpoint callbacks monitoring ``val_acc``. Its
    ``state_dict`` is then saved and loaded back. If ``check_model_io`` is
    set, the whole model object is also saved and loaded back.

    Every file this function produces goes into a private temporary
    directory that is deleted on exit, even when training or loading
    raises. This covers the checkpoint written by ``ModelCheckpoint``, the
    weights file and the full-model file. Nothing is left in the working
    directory, and concurrent runs cannot overwrite each other's files.

    :param model: the model under test; it must provide ``compile``, ``fit``,
        ``state_dict`` and ``load_state_dict``.
    :param model_name: prefix of the progress messages printed to stdout.
    :param x: training inputs passed to ``model.fit``.
    :param y: training labels passed to ``model.fit``.
    :param check_model_io: if True, also save and reload the entire model
        object, not just its ``state_dict``.
    :return: None. Any failure propagates as an exception.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as work_dir:
        early_stopping = EarlyStopping(monitor='val_acc', min_delta=0, verbose=1, patience=0, mode='max')
        # Checkpoint inside the temp dir; the old './model.ckpt' was never removed.
        model_checkpoint = ModelCheckpoint(filepath=os.path.join(work_dir, 'model.ckpt'),
                                           monitor='val_acc', verbose=1,
                                           save_best_only=True,
                                           save_weights_only=False, mode='max', period=1)

        model.compile('adam', 'binary_crossentropy',
                      metrics=['binary_crossentropy', 'acc'])
        # validation_split is what produces the val_* metrics both callbacks monitor.
        model.fit(x, y, batch_size=100, epochs=1, validation_split=0.5,
                  callbacks=[early_stopping, model_checkpoint])
        print(model_name + 'test, train valid pass!')

        # The directory is private to this call, so fixed file names cannot
        # collide, and model_name (which may contain path separators) is not
        # needed in the path.
        weights_path = os.path.join(work_dir, 'weights.h5')
        torch.save(model.state_dict(), weights_path)
        model.load_state_dict(_torch_load_compat(weights_path))
        print(model_name + 'test save load weight pass!')

        if check_model_io:
            model_path = os.path.join(work_dir, 'model.h5')
            torch.save(model, model_path)
            # Loading a whole pickled model is only acceptable because this
            # function wrote the file itself; never point this at untrusted data.
            _torch_load_compat(model_path)
            print(model_name + 'test save load model pass!')
    print(model_name + 'test pass!')
| 180 | |
| 181 | |
| 182 | def get_device(use_cuda=True): |