(params, misc)
| 21 | |
@staticmethod
def init(params, misc):
    """Create the initial parameters for the image/word encoders and the generator.

    Builds an image encoder (``We``, ``be``) and a word encoder (``Ws``),
    then asks the generator class selected by ``params['generator']`` (via
    ``decodeGenerator``) for its own initial struct and merges it in.

    Args:
        params: dict of hyperparameters. Keys read, with their defaults:
            ``image_encoding_size`` (128), ``word_encoding_size`` (128),
            ``hidden_size`` (128), ``generator`` ('lstm') and
            ``image_size`` (4096, the length of each CNN image vector;
            previously hard-coded).
        misc: dict that must contain ``'wordtoix'`` and ``'ixtoword'``. Only
            their lengths are used: ``len(misc['wordtoix'])`` sizes the word
            encoder, ``len(misc['ixtoword'])`` is passed to the generator as
            its output size.

    Returns:
        dict with keys ``'model'`` (parameter name -> value), ``'update'``
        (names of parameters to update) and ``'regularize'`` (names of
        parameters to regularize), after the generator's struct has been
        merged in by ``merge_init_structs`` (presumably in place, since its
        return value is not used).

    Raises:
        ValueError: if ``generator == 'lstm'`` and ``image_encoding_size``
            differs from ``word_encoding_size``.
    """
    # inputs
    image_encoding_size = params.get('image_encoding_size', 128)
    word_encoding_size = params.get('word_encoding_size', 128)
    hidden_size = params.get('hidden_size', 128)
    generator = params.get('generator', 'lstm')
    vocabulary_size = len(misc['wordtoix'])
    output_size = len(misc['ixtoword'])  # expected to equal vocabulary_size; not checked here
    # Size of the CNN image vectors. Defaults to the previously hard-coded 4096
    # but can now be overridden through params.
    image_size = params.get('image_size', 4096)

    # The generator below is initialized with word_encoding_size as its input
    # size, so for the LSTM the image encodings presumably have to share that
    # width. Raise instead of assert so the check survives `python -O`.
    if generator == 'lstm' and image_encoding_size != word_encoding_size:
        raise ValueError('this implementation does not support different sizes for these parameters')

    # initialize the encoder models
    model = {}
    model['We'] = initw(image_size, image_encoding_size)  # image encoder weights
    model['be'] = np.zeros((1, image_encoding_size))  # image encoder bias, one row
    model['Ws'] = initw(vocabulary_size, word_encoding_size)  # word encoder
    update = ['We', 'be', 'Ws']
    regularize = ['We', 'Ws']  # the bias 'be' is updated but not regularized
    init_struct = {'model': model, 'update': update, 'regularize': regularize}

    # descend into the specific Generator and initialize it
    Generator = decodeGenerator(generator)
    generator_init_struct = Generator.init(word_encoding_size, hidden_size, output_size)
    merge_init_structs(init_struct, generator_init_struct)
    return init_struct
| 51 | |
| 52 | @staticmethod |
| 53 | def forward(batch, model, params, misc, predict_mode = False): |
nothing calls this directly
no test coverage detected