github.com/shenweichen/DeepCTR-Torch / check_model

Function check_model

tests/utils.py:150–179

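Compiles the model, trains and evaluates it for one epoch, then verifies that its weights and (optionally) the whole model object survive a save/load round trip. Parameters: model, the DeepCTR-Torch model under test; model_name, a prefix for log messages and temporary file names; x and y, the inputs and labels passed to fit; check_model_io, whether to also save and reload the whole model. Returns None.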
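The evaluation step reports the metrics the model is compiled with, metrics=['binary_crossentropy', 'acc'], and 'val_acc' is what both callbacks monitor. A minimal, framework-free sketch of those two metrics follows. It assumes 'acc' thresholds predicted probabilities at 0.5 (p > 0.5 counts as class 1); the function names and the eps clipping that avoids log(0) are illustrative assumptions, not DeepCTR-Torch's code.

```python
import math
from typing import Sequence


def binary_crossentropy(y_true: Sequence[int], y_prob: Sequence[float], eps: float = 1e-15) -> float:
    """Mean log loss; probabilities are clipped to [eps, 1 - eps] so log() never sees 0."""
    total = 0.0
    for t, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)
        total += t * math.log(p) + (1 - t) * math.log(1.0 - p)
    return -total / len(y_true)


def accuracy(y_true: Sequence[int], y_prob: Sequence[float], threshold: float = 0.5) -> float:
    """Fraction of samples whose thresholded prediction (p > threshold) equals the label."""
    hits = sum(int(p > threshold) == t for t, p in zip(y_true, y_prob))
    return hits / len(y_true)


if __name__ == "__main__":
    y = [1, 0, 1, 0]
    p = [0.9, 0.2, 0.6, 0.4]
    print(round(binary_crossentropy(y, p), 4))  # 0.3375
    print(accuracy(y, p))  # 1.0
```

Every prediction lands on the correct side of 0.5, so accuracy is perfect, while the log loss still penalizes the two hesitant predictions (0.6 and 0.4) the most.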

(model, model_name, x, y, check_model_io=True)
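Inside check_model, x and y go straight to fit(x, y, batch_size=100, epochs=1, validation_split=0.5). Assuming fit follows the Keras convention of holding out the trailing fraction of samples (no shuffling before the split) and running ceil(train_size / batch_size) batches per epoch, the bookkeeping can be sketched as below; split_plan and SplitPlan are hypothetical helpers, not part of the library.

```python
import math
from typing import NamedTuple


class SplitPlan(NamedTuple):
    train_size: int
    val_size: int
    steps_per_epoch: int


def split_plan(n_samples: int, validation_split: float = 0.5, batch_size: int = 100) -> SplitPlan:
    """Sizes of the train/validation parts and the number of training batches per epoch."""
    # Keras-style split: the first int(n * (1 - validation_split)) samples train,
    # the remaining tail is used for validation.
    split_at = int(n_samples * (1.0 - validation_split))
    steps = math.ceil(split_at / batch_size)  # the last batch may be partial
    return SplitPlan(split_at, n_samples - split_at, steps)


if __name__ == "__main__":
    print(split_plan(5))    # SplitPlan(train_size=2, val_size=3, steps_per_epoch=1)
    print(split_plan(450))  # SplitPlan(train_size=225, val_size=225, steps_per_epoch=3)
```

With an odd sample count, int() truncates, so the validation part gets the extra sample.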

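```python
import os

import torch
from deepctr_torch.callbacks import EarlyStopping, ModelCheckpoint

# _torch_load_compat is a loading helper defined elsewhere in tests/utils.py.


def check_model(model, model_name, x, y, check_model_io=True):
    '''
    Compile the model, train and evaluate it, then save/load its weights and,
    optionally, the whole model object.
    :param model: DeepCTR-Torch model under test.
    :param model_name: prefix for log messages and temporary file names.
    :param x: model inputs, passed straight to fit.
    :param y: binary labels, passed straight to fit.
    :param check_model_io: if True, also round-trip the whole model with torch.save.
    :return: None
    '''
    # Both callbacks monitor val_acc (higher is better); only the best checkpoint
    # is written to model.ckpt.
    early_stopping = EarlyStopping(monitor='val_acc', min_delta=0, verbose=1, patience=0, mode='max')
    model_checkpoint = ModelCheckpoint(filepath='model.ckpt', monitor='val_acc', verbose=1,
                                       save_best_only=True,
                                       save_weights_only=False, mode='max', period=1)

    model.compile('adam', 'binary_crossentropy',
                  metrics=['binary_crossentropy', 'acc'])
    # One epoch; validation_split=0.5 holds out half of the samples for validation.
    model.fit(x, y, batch_size=100, epochs=1, validation_split=0.5,
              callbacks=[early_stopping, model_checkpoint])
    print(model_name + ' test, train valid pass!')

    # Round-trip the parameters only (state_dict). torch.save does not write HDF5;
    # the '.h5' suffix is just part of the file name.
    torch.save(model.state_dict(), model_name + '_weights.h5')
    model.load_state_dict(_torch_load_compat(model_name + '_weights.h5'))
    os.remove(model_name + '_weights.h5')
    print(model_name + ' test save load weight pass!')

    if check_model_io:
        # Round-trip the whole pickled model object.
        torch.save(model, model_name + '.h5')
        model = _torch_load_compat(model_name + '.h5')
        os.remove(model_name + '.h5')
        print(model_name + ' test save load model pass!')
    print(model_name + ' test pass!')
```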
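check_model writes its files into the current working directory and deletes each one only after a successful load, so an exception during loading leaves the file behind; the model.ckpt written by ModelCheckpoint is never removed by this function at all. A minimal sketch of the same save, load, delete round trip that always cleans up by working inside a temporary directory is shown below. It substitutes the standard-library pickle module for torch.save and _torch_load_compat so it runs without PyTorch; roundtrip, pickle_dump and pickle_load are hypothetical helpers.

```python
import os
import pickle
import tempfile
from typing import Any, Callable


def roundtrip(obj: Any, name: str,
              dump: Callable[[Any, str], None],
              load: Callable[[str], Any]) -> Any:
    """Save obj to a temporary file, load it back, and always delete the file."""
    with tempfile.TemporaryDirectory() as tmp:  # removed even if load() raises
        path = os.path.join(tmp, name)
        dump(obj, path)
        return load(path)


def pickle_dump(obj: Any, path: str) -> None:
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def pickle_load(path: str) -> Any:
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    # Stand-in for model.state_dict(): parameter name -> values.
    state = {"linear.weight": [0.1, -0.2], "linear.bias": [0.0]}
    restored = roundtrip(state, "demo_weights.pkl", pickle_dump, pickle_load)
    assert restored == state
    print("weights round trip pass!")
```

With PyTorch installed, the same helper could be called as roundtrip(model.state_dict(), model_name + '_weights.pt', torch.save, torch.load) and the result passed to model.load_state_dict.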

Callers (15; 12 listed)

test_CCPM
test_CCPM_without_seq
test_AFN
test_FiBiNET
test_AFM
test_DIEN
test_DCN
test_DCNMix
test_WDL
test_PNN
test_DIN
test_NFM

Calls (5)

EarlyStopping (class)
ModelCheckpoint (class)
compile (method)
fit (method)
_torch_load_compat (function)
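Of these callees, EarlyStopping(min_delta=0, patience=0) and ModelCheckpoint(save_best_only=True) both watch 'val_acc' with mode='max'. The framework-free sketch below follows the classic Keras rule: an epoch improves only if val_acc minus min_delta strictly exceeds the best value so far, the wait counter grows only on non-improving epochs, and training stops once it reaches patience. Newer TensorFlow releases count patience slightly differently, so treat this as illustrative. MaxMonitor and run are hypothetical names, and one shared best value serves both callbacks, which is equivalent here only because min_delta=0.

```python
import math
from typing import List, Optional, Sequence


class MaxMonitor:
    """Early-stopping and best-only checkpoint bookkeeping for a higher-is-better metric."""

    def __init__(self, min_delta: float = 0.0, patience: int = 0) -> None:
        self.best = -math.inf          # nothing seen yet, so the first epoch always improves
        self.min_delta = min_delta
        self.patience = patience
        self.wait = 0                  # consecutive epochs without improvement
        self.saved_epochs: List[int] = []
        self.stopped_epoch: Optional[int] = None

    def on_epoch_end(self, epoch: int, val_acc: float) -> bool:
        """Record one epoch's val_acc; return True if training should stop."""
        if val_acc - self.min_delta > self.best:
            self.best = val_acc
            self.wait = 0
            self.saved_epochs.append(epoch)  # save_best_only: write a checkpoint only here
            return False
        self.wait += 1
        if self.wait >= self.patience:  # patience=0: stop at the first non-improving epoch
            self.stopped_epoch = epoch
            return True
        return False


def run(val_accs: Sequence[float], min_delta: float = 0.0, patience: int = 0) -> MaxMonitor:
    """Feed per-epoch val_acc values through the monitor until it asks to stop."""
    monitor = MaxMonitor(min_delta, patience)
    for epoch, acc in enumerate(val_accs):
        if monitor.on_epoch_end(epoch, acc):
            break
    return monitor


if __name__ == "__main__":
    one_epoch = run([0.71])  # epochs=1, as in check_model
    print(one_epoch.saved_epochs, one_epoch.stopped_epoch)  # [0] None
    longer = run([0.70, 0.72, 0.72, 0.75])
    print(longer.saved_epochs, longer.stopped_epoch)  # [0, 1] 2
```

With epochs=1, as in check_model, the only epoch always beats the initial best of -inf (assuming val_acc is a finite number), so EarlyStopping never fires and model.ckpt is written exactly once. The test therefore exercises the callback plumbing rather than the stopping rule itself.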
