MCPcopy
hub / github.com/karpathy/neuraltalk / main

Function main

eval_sentence_predictions.py:16–107  ·  view source on GitHub ↗
(params)

Source from the content-addressed store, hash-verified

14from imagernn.imagernn_utils import decodeGenerator, eval_split
15
16def main(params):
17
18 # load the checkpoint
19 checkpoint_path = params['checkpoint_path']
20 max_images = params['max_images']
21
22 print 'loading checkpoint %s' % (checkpoint_path, )
23 checkpoint = pickle.load(open(checkpoint_path, 'rb'))
24 checkpoint_params = checkpoint['params']
25 dataset = checkpoint_params['dataset']
26 model = checkpoint['model']
27 dump_folder = params['dump_folder']
28
29 if dump_folder:
30 print 'creating dump folder ' + dump_folder
31 os.system('mkdir -p ' + dump_folder)
32
33 # fetch the data provider
34 dp = getDataProvider(dataset)
35
36 misc = {}
37 misc['wordtoix'] = checkpoint['wordtoix']
38 ixtoword = checkpoint['ixtoword']
39
40 blob = {} # output blob which we will dump to JSON for visualizing the results
41 blob['params'] = params
42 blob['checkpoint_params'] = checkpoint_params
43 blob['imgblobs'] = []
44
45 # iterate over all images in test set and predict sentences
46 BatchGenerator = decodeGenerator(checkpoint_params)
47 n = 0
48 all_references = []
49 all_candidates = []
50 for img in dp.iterImages(split = 'test', max_images = max_images):
51 n+=1
52 print 'image %d/%d:' % (n, max_images)
53 references = [' '.join(x['tokens']) for x in img['sentences']] # as list of lists of tokens
54 kwparams = { 'beam_size' : params['beam_size'] }
55 Ys = BatchGenerator.predict([{'image':img}], model, checkpoint_params, **kwparams)
56
57 img_blob = {} # we will build this up
58 img_blob['img_path'] = img['local_file_path']
59 img_blob['imgid'] = img['imgid']
60
61 if dump_folder:
62 # copy source file to some folder. This makes it easier to distribute results
63 # into a webpage, because all images that were predicted on are in a single folder
64 source_file = img['local_file_path']
65 target_file = os.path.join(dump_folder, os.path.basename(img['local_file_path']))
66 os.system('cp %s %s' % (source_file, target_file))
67
68 # encode the human-provided references
69 img_blob['references'] = []
70 for gtsent in references:
71 print 'GT: ' + gtsent
72 img_blob['references'].append({'text': gtsent})
73

Callers 1

Calls 5

getDataProvider (Function) · 0.90
decodeGenerator (Function) · 0.90
eval_split (Function) · 0.90
iterImages (Method) · 0.80
predict (Method) · 0.45

Tested by

no test coverage detected