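Evaluate the model. Args: model_spec: (dict) contains the graph operations or nodes needed for evaluation; model_dir: (string) directory containing config, weights and log; params: (Params) contains hyperparameters of the model. Must define: num_epochs, train_size, batch_size, eval_size, save_summary_steps; restore_from: (string) directory or file containing weights to restore the graph.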
(model_spec, model_dir, params, restore_from)
| 48 | |
| 49 | |
def evaluate(model_spec, model_dir, params, restore_from):
    """Evaluate the model from saved weights and write the metrics to JSON.

    Args:
        model_spec: (dict) contains the graph operations or nodes needed for
            evaluation. Must contain 'variable_init_op'; the whole dict is
            also passed to evaluate_sess.
        model_dir: (string) directory containing config, weights and log.
            The metrics JSON file is written here.
        params: (Params) contains hyperparameters of the model.
            This function reads batch_size and eval_size.
        restore_from: (string) directory or file containing weights to
            restore the graph, joined onto model_dir with os.path.join (so an
            absolute path is used as-is). If it names a directory, the latest
            checkpoint in that directory is restored.

    Returns:
        The metrics returned by evaluate_sess. They are also saved to
        model_dir/metrics_test_<name>.json, where <name> is restore_from
        with every '/' replaced by '_'.

    Raises:
        ValueError: if restore_from resolves to a directory that contains
            no checkpoint.
    """
    # With no arguments, tf.train.Saver covers the saveable variables of the
    # default graph; build it before opening the session.
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # Run the init op first (the original comment says it initializes the
        # lookup table); saver.restore below then overwrites the values of the
        # variables stored in the checkpoint.
        sess.run(model_spec['variable_init_op'])

        # Resolve the weights to restore: a directory means "its latest checkpoint".
        weights_path = os.path.join(model_dir, restore_from)
        if os.path.isdir(weights_path):
            checkpoint = tf.train.latest_checkpoint(weights_path)
            # latest_checkpoint returns None when no checkpoint exists; passing
            # None on to saver.restore would fail with an unhelpful error.
            if checkpoint is None:
                raise ValueError(
                    "No checkpoint found in directory {}".format(weights_path))
            weights_path = checkpoint
        saver.restore(sess, weights_path)

        # Ceil division: a final partial batch counts as one more step.
        num_steps = (params.eval_size + params.batch_size - 1) // params.batch_size
        metrics = evaluate_sess(sess, model_spec, num_steps)

        # Encode restore_from in the file name, e.g. "a/b" -> metrics_test_a_b.json.
        metrics_name = restore_from.replace('/', '_')
        metrics_path = os.path.join(
            model_dir, "metrics_test_{}.json".format(metrics_name))
        save_dict_to_json(metrics, metrics_path)

    return metrics
no test coverage detected