(sess, embedding_dim, batch_size)
| 50 | |
| 51 | |
def build_model(sess, embedding_dim, batch_size):
    """Build the test-time two-stage generator graph and restore its weights.

    Args:
        sess: TensorFlow session into which checkpoint variables are restored.
        embedding_dim: Width of the conditioning-embedding placeholder.
        batch_size: Number of samples per batch; fixes the first dimension of
            both the embedding placeholder and the noise tensor.

    Returns:
        A tuple ``(embeddings, fake_images, hr_fake_images)``:
            - the ``conditional_embeddings`` float32 placeholder that callers
              feed;
            - the output of the generator built under ``g_net``;
            - the output of the generator built under ``hr_g_net``, which
              takes the ``g_net`` images as its input.

    Note:
        Weights are restored only when ``cfg.TEST.PRETRAINED_MODEL`` contains
        the substring ``'.ckpt'``. Otherwise a message is printed and the
        graph is still returned, with its variables NOT restored.
        NOTE(review): callers presumably must not run the returned tensors
        in that case; confirm how callers handle it.
    """
    # Integer ratio between the high- and low-resolution image sizes.
    upscale = int(cfg.TEST.HR_IMSIZE/cfg.TEST.LR_IMSIZE)
    gan = CondGAN(lr_imsize=cfg.TEST.LR_IMSIZE, hr_lr_ratio=upscale)

    cond_input = tf.placeholder(
        tf.float32,
        [batch_size, embedding_dim],
        name='conditional_embeddings',
    )

    # Every prettytensor layer created below defaults to the test phase.
    with pt.defaults_scope(phase=pt.Phase.test):
        # First generator: conditioning context plus Gaussian noise.
        with tf.variable_scope("g_net"):
            lr_context = sample_encoded_context(cond_input, gan)
            noise = tf.random_normal([batch_size, cfg.Z_DIM])
            # Axis-first tf.concat(axis, values) signature, as in TensorFlow
            # releases before 1.0.
            lr_images = gan.get_generator(tf.concat(1, [lr_context, noise]))
        # Second generator: its own context sample plus the first stage's images.
        with tf.variable_scope("hr_g_net"):
            hr_context = sample_encoded_context(cond_input, gan)
            hr_images = gan.hr_get_generator(lr_images, hr_context)

    checkpoint = cfg.TEST.PRETRAINED_MODEL
    if checkpoint.find('.ckpt') == -1:
        # No restore happens; the graph is returned with un-restored variables.
        print("Input a valid model path.")
    else:
        print("Reading model parameters from %s" % checkpoint)
        restorer = tf.train.Saver(tf.all_variables())
        restorer.restore(sess, checkpoint)

    return cond_input, lr_images, hr_images
| 77 | |
| 78 | |
| 79 | def drawCaption(img, caption): |
no test coverage detected