Finds and returns a model architecture and its parameters from the database that match the given query criteria. Parameters ---------- sort : List of tuple PyMongo sort specification; search "PyMongo find one sorting" and see `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
(self, sort=None, model_name='model', **kwargs)
| 169 | return False |
| 170 | |
| 171 | def find_top_model(self, sort=None, model_name='model', **kwargs): |
| 172 | """Finds and returns a model architecture and its parameters from the database which matches the requirement. |
| 173 | |
| 174 | Parameters |
| 175 | ---------- |
| 176 | sort : List of tuple |
| 177 | PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details. |
| 178 | model_name : str or None |
| 179 | The name/key of model. |
| 180 | kwargs : other events |
| 181 | Other events, such as name, accuracy, loss, step number and etc (optinal). |
| 182 | |
| 183 | Examples |
| 184 | --------- |
| 185 | - see ``save_model``. |
| 186 | |
| 187 | Returns |
| 188 | --------- |
| 189 | network : TensorLayer Model |
| 190 | Note that, the returned network contains all information of the document (record), e.g. if you saved accuracy in the document, you can get the accuracy by using ``net._accuracy``. |
| 191 | """ |
| 192 | # print(kwargs) # {} |
| 193 | kwargs.update({'model_name': model_name}) |
| 194 | self._fill_project_info(kwargs) |
| 195 | |
| 196 | s = time.time() |
| 197 | |
| 198 | d = self.db.Model.find_one(filter=kwargs, sort=sort) |
| 199 | |
| 200 | # _temp_file_name = '_find_one_model_ztemp_file' |
| 201 | if d is not None: |
| 202 | params_id = d['params_id'] |
| 203 | graphs = d['architecture'] |
| 204 | _datetime = d['time'] |
| 205 | # exists_or_mkdir(_temp_file_name, False) |
| 206 | # with open(os.path.join(_temp_file_name, 'graph.pkl'), 'wb') as file: |
| 207 | # pickle.dump(graphs, file, protocol=pickle.HIGHEST_PROTOCOL) |
| 208 | else: |
| 209 | print("[Database] FAIL! Cannot find model: {}".format(kwargs)) |
| 210 | return False |
| 211 | try: |
| 212 | params = self._deserialization(self.model_fs.get(params_id).read()) |
| 213 | # TODO : restore model and load weights |
| 214 | network = static_graph2net(graphs) |
| 215 | assign_weights(weights=params, network=network) |
| 216 | # np.savez(os.path.join(_temp_file_name, 'params.npz'), params=params) |
| 217 | # |
| 218 | # network = load_graph_and_params(name=_temp_file_name, sess=sess) |
| 219 | # del_folder(_temp_file_name) |
| 220 | |
| 221 | pc = self.db.Model.find(kwargs) |
| 222 | print( |
| 223 | "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s".format( |
| 224 | kwargs, sort, _datetime, round(time.time() - s, 2) |
| 225 | ) |
| 226 | ) |
| 227 | |
| 228 | # FIXME : not sure what's this for |
no test coverage detected