(params)
| 14 | from imagernn.imagernn_utils import decodeGenerator, eval_split |
| 15 | |
| 16 | def main(params): |
| 17 | |
| 18 | # load the checkpoint |
| 19 | checkpoint_path = params['checkpoint_path'] |
| 20 | max_images = params['max_images'] |
| 21 | |
| 22 | print 'loading checkpoint %s' % (checkpoint_path, ) |
| 23 | checkpoint = pickle.load(open(checkpoint_path, 'rb')) |
| 24 | checkpoint_params = checkpoint['params'] |
| 25 | dataset = checkpoint_params['dataset'] |
| 26 | model = checkpoint['model'] |
| 27 | dump_folder = params['dump_folder'] |
| 28 | |
| 29 | if dump_folder: |
| 30 | print 'creating dump folder ' + dump_folder |
| 31 | os.system('mkdir -p ' + dump_folder) |
| 32 | |
| 33 | # fetch the data provider |
| 34 | dp = getDataProvider(dataset) |
| 35 | |
| 36 | misc = {} |
| 37 | misc['wordtoix'] = checkpoint['wordtoix'] |
| 38 | ixtoword = checkpoint['ixtoword'] |
| 39 | |
| 40 | blob = {} # output blob which we will dump to JSON for visualizing the results |
| 41 | blob['params'] = params |
| 42 | blob['checkpoint_params'] = checkpoint_params |
| 43 | blob['imgblobs'] = [] |
| 44 | |
| 45 | # iterate over all images in test set and predict sentences |
| 46 | BatchGenerator = decodeGenerator(checkpoint_params) |
| 47 | n = 0 |
| 48 | all_references = [] |
| 49 | all_candidates = [] |
| 50 | for img in dp.iterImages(split = 'test', max_images = max_images): |
| 51 | n+=1 |
| 52 | print 'image %d/%d:' % (n, max_images) |
| 53 | references = [' '.join(x['tokens']) for x in img['sentences']] # as list of lists of tokens |
| 54 | kwparams = { 'beam_size' : params['beam_size'] } |
| 55 | Ys = BatchGenerator.predict([{'image':img}], model, checkpoint_params, **kwparams) |
| 56 | |
| 57 | img_blob = {} # we will build this up |
| 58 | img_blob['img_path'] = img['local_file_path'] |
| 59 | img_blob['imgid'] = img['imgid'] |
| 60 | |
| 61 | if dump_folder: |
| 62 | # copy source file to some folder. This makes it easier to distribute results |
| 63 | # into a webpage, because all images that were predicted on are in a single folder |
| 64 | source_file = img['local_file_path'] |
| 65 | target_file = os.path.join(dump_folder, os.path.basename(img['local_file_path'])) |
| 66 | os.system('cp %s %s' % (source_file, target_file)) |
| 67 | |
| 68 | # encode the human-provided references |
| 69 | img_blob['references'] = [] |
| 70 | for gtsent in references: |
| 71 | print 'GT: ' + gtsent |
| 72 | img_blob['references'].append({'text': gtsent}) |
| 73 |
no test coverage detected