MCPcopy
hub / github.com/hunkim/PyTorchZeroToAll / test

Function test

13_2_rnn_classification.py:161–182  ·  view source on GitHub ↗
(name=None)

Source from the content-addressed store, hash-verified

159
# Testing cycle
def test(name=None):
    """Classify a single name, or report accuracy over the whole test set.

    Relies on module-level globals defined elsewhere in this script:
    ``classifier``, ``make_variables``, ``train_dataset`` and ``test_loader``.

    Args:
        name: If truthy, predict the country for this one name and print it.
            If ``None`` (the default) or otherwise falsy, run ``classifier``
            over every batch of ``test_loader`` and print the accuracy.

    Returns:
        None. Results are printed to stdout.
    """
    # Predict for a given name
    if name:
        # No countries are passed, so the returned target is unused here.
        inputs, seq_lengths, _ = make_variables([name], [])
        output = classifier(inputs, seq_lengths)
        # Index of the highest-scoring class along dim 1, kept as a column.
        pred = output.data.max(1, keepdim=True)[1]
        # Only one name was passed in, so take the first row's prediction.
        country_id = pred.cpu().numpy()[0][0]
        print(name, "is", train_dataset.get_country(country_id))
        return

    print("evaluating trained model ...")
    correct = 0
    # Number of examples in the *test* split (used as the accuracy denominator).
    test_data_size = len(test_loader.dataset)

    for names, countries in test_loader:
        inputs, seq_lengths, target = make_variables(names, countries)
        output = classifier(inputs, seq_lengths)
        pred = output.data.max(1, keepdim=True)[1]
        # int() converts the per-batch match count to a plain Python int so
        # `correct` never turns into a tensor; the print/division below is
        # then ordinary Python arithmetic regardless of PyTorch version.
        correct += int(pred.eq(target.data.view_as(pred)).cpu().sum())

    # Avoid ZeroDivisionError when the test dataset is empty.
    accuracy = 100. * correct / test_data_size if test_data_size else 0.
    print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, test_data_size, accuracy))
183
184
185if __name__ == '__main__':

Callers 1

Calls 2

get_country · Method · 0.80
make_variables · Function · 0.70

Tested by

no test coverage detected