
Method find_top_model

tensorlayer/db.py:171–243

Finds and returns a model architecture and its parameters from the database that match the requirement.

(self, sort=None, model_name='model', **kwargs)
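The signature implies that `model_name` and any extra keyword arguments are merged into one exact-match filter, while `sort` decides which matching document comes first. The sketch below imitates that behaviour on an in-memory list so it runs without MongoDB. `build_filter` and `find_one_in_memory` are hypothetical helpers written for illustration and are not part of TensorLayer or PyMongo. Only the direction constants mirror real values (`pymongo.ASCENDING` is 1, `pymongo.DESCENDING` is -1). The project-scoping step done by `_fill_project_info` is left out because its body is not shown here.

```python
# Hypothetical stand-ins for illustration; not part of TensorLayer or PyMongo.
ASCENDING = 1     # same numeric value as pymongo.ASCENDING
DESCENDING = -1   # same numeric value as pymongo.DESCENDING


def build_filter(model_name='model', **kwargs):
    """Fold model_name into the keyword arguments to form one filter dict."""
    query = dict(kwargs)
    query.update({'model_name': model_name})
    return query


def find_one_in_memory(documents, query, sort=None):
    """Small imitation of find_one(filter=..., sort=...) with exact-match filters only."""
    matches = [d for d in documents if all(d.get(k) == v for k, v in query.items())]
    # list.sort is stable, so applying the keys from last to first
    # gives the first (key, direction) pair the highest priority.
    for key, direction in reversed(sort or []):
        matches.sort(key=lambda d: d[key], reverse=(direction == DESCENDING))
    return matches[0] if matches else None


documents = [
    {'model_name': 'mlp', 'accuracy': 0.91, 'step': 100},
    {'model_name': 'mlp', 'accuracy': 0.97, 'step': 300},
    {'model_name': 'cnn', 'accuracy': 0.99, 'step': 200},
]

query = build_filter(model_name='mlp')
best = find_one_in_memory(documents, query, sort=[('accuracy', DESCENDING)])
missing = find_one_in_memory(documents, build_filter(model_name='rnn'))
```

With a real database the equivalent call would presumably look like `db.find_top_model(sort=[('accuracy', pymongo.DESCENDING)], model_name='mlp')`, where `db` is an already-connected TensorLayer database object.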

Source from the content-addressed store, hash-verified

def find_top_model(self, sort=None, model_name='model', **kwargs):
    """Finds and returns a model architecture and its parameters from the database that match the requirement.

    Parameters
    ----------
    sort : List of tuple
        PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
    model_name : str or None
        The name/key of model.
    kwargs : other events
        Other events, such as name, accuracy, loss, step number, etc. (optional).

    Examples
    --------
    - see ``save_model``.

    Returns
    -------
    network : TensorLayer Model
        Note that the returned network contains all information of the document (record), e.g. if you saved accuracy in the document, you can get the accuracy by using ``net._accuracy``.
    """
    # print(kwargs) # {}
    kwargs.update({'model_name': model_name})
    self._fill_project_info(kwargs)

    s = time.time()

    d = self.db.Model.find_one(filter=kwargs, sort=sort)

    # _temp_file_name = '_find_one_model_ztemp_file'
    if d is not None:
        params_id = d['params_id']
        graphs = d['architecture']
        _datetime = d['time']
        # exists_or_mkdir(_temp_file_name, False)
        # with open(os.path.join(_temp_file_name, 'graph.pkl'), 'wb') as file:
        #     pickle.dump(graphs, file, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        print("[Database] FAIL! Cannot find model: {}".format(kwargs))
        return False
    try:
        params = self._deserialization(self.model_fs.get(params_id).read())
        # TODO : restore model and load weights
        network = static_graph2net(graphs)
        assign_weights(weights=params, network=network)
        # np.savez(os.path.join(_temp_file_name, 'params.npz'), params=params)
        #
        # network = load_graph_and_params(name=_temp_file_name, sess=sess)
        # del_folder(_temp_file_name)

        pc = self.db.Model.find(kwargs)
        print(
            "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s".format(
                kwargs, sort, _datetime, round(time.time() - s, 2)
            )
        )

        # FIXME : not sure what's this for
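The listing shows that when no document matches, the method prints a message and returns `False` instead of raising. A caller therefore needs to check the result before treating it as a network. The sketch below shows one way to do that. `StubDB` and `load_or_default` are hypothetical names for illustration only. The stub copies just the method name and the `False`-on-miss convention seen above, and it does not model any other failure path in the part of the method that is not shown.

```python
class StubDB:
    """Hypothetical stand-in that copies only the return convention shown above."""

    def __init__(self, models):
        self._models = models

    def find_top_model(self, sort=None, model_name='model', **kwargs):
        # Return the stored object, or False when nothing matches.
        return self._models.get(model_name, False)


def load_or_default(db, default=None, **query):
    """Return the model found by db.find_top_model, or `default` on a miss."""
    network = db.find_top_model(**query)
    # Test identity with False rather than truthiness, so that a valid
    # but "falsy" object is not mistaken for a failed lookup.
    if network is False:
        return default
    return network


db = StubDB({'mlp': 'mlp-network-object'})
found = load_or_default(db, model_name='mlp')
fallback = load_or_default(db, default='fresh-network', model_name='rnn')
```

Checking `is False` keeps the failure signal explicit. A plain `if not network:` would also work for the case shown, but it would hide the difference between "not found" and any object that happens to evaluate as false.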

Callers 1

dispatch_tasks.py · File · 0.80

Calls 6

_fill_project_info · Method · 0.95
_deserialization · Method · 0.95
static_graph2net · Function · 0.90
assign_weights · Function · 0.90
get · Method · 0.80
update · Method · 0.45

Tested by

no test coverage detected