Repository: github.com/hanzhanggit/StackGAN

Function build_model

demo/birds_skip_thought_demo.py:51–75
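(sess, embedding_dim, batch_size) → (embeddings, fake_images, hr_fake_images)

Builds the Stage-I (low-resolution) and Stage-II (high-resolution) generators with prettytensor's test-phase defaults. Both stages are fed from an embeddings placeholder of shape [batch_size, embedding_dim]. If cfg.TEST.PRETRAINED_MODEL contains ".ckpt", the function restores the weights into sess.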


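```python
def build_model(sess, embedding_dim, batch_size):
    # Relies on module-level names: tf (TensorFlow 0.x API), pt (prettytensor),
    # cfg, CondGAN and sample_encoded_context.
    # One CondGAN object supplies both the Stage-I and the Stage-II generator;
    # hr_lr_ratio is the upsampling factor between the two image sizes.
    model = CondGAN(
        lr_imsize=cfg.TEST.LR_IMSIZE,
        hr_lr_ratio=int(cfg.TEST.HR_IMSIZE/cfg.TEST.LR_IMSIZE))

    # Text embeddings are fed in at run time through this placeholder.
    embeddings = tf.placeholder(
        tf.float32, [batch_size, embedding_dim],
        name='conditional_embeddings')
    # Build both generators with prettytensor defaults set to the test phase.
    with pt.defaults_scope(phase=pt.Phase.test):
        with tf.variable_scope("g_net"):
            # Stage I: conditioning vector c plus noise z -> low-resolution images.
            # tf.concat(1, [c, z]) uses the pre-1.0 (axis, values) argument order.
            c = sample_encoded_context(embeddings, model)
            z = tf.random_normal([batch_size, cfg.Z_DIM])
            fake_images = model.get_generator(tf.concat(1, [c, z]))
        with tf.variable_scope("hr_g_net"):
            # Stage II: a second conditioning vector, built under its own scope,
            # refines the Stage-I images into high-resolution images.
            hr_c = sample_encoded_context(embeddings, model)
            hr_fake_images = model.hr_get_generator(fake_images, hr_c)

    ckt_path = cfg.TEST.PRETRAINED_MODEL
    if ckt_path.find('.ckpt') != -1:
        print("Reading model parameters from %s" % ckt_path)
        # Restore every graph variable (both stages) from the checkpoint.
        # tf.all_variables was later renamed tf.global_variables.
        saver = tf.train.Saver(tf.all_variables())
        saver.restore(sess, ckt_path)
    else:
        # Nothing is restored here, so the returned tensors depend on
        # variables that are still uninitialized.
        print("Input a valid model path.")
    return embeddings, fake_images, hr_fake_images
```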
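The Stage-I generator input is the conditioning vector c concatenated with the noise z along axis 1. In build_model this is the TF 0.x call tf.concat(1, [c, z]); TensorFlow 1.0 and later write it as tf.concat([c, z], axis=1). The NumPy sketch below reproduces only the shape bookkeeping and the hr_lr_ratio arithmetic. The sizes (a 128-dimensional c, Z_DIM = 100, 64 px Stage-I and 256 px Stage-II images) are illustrative assumptions and were not read from this repository's config files.

```python
import numpy as np

# Illustrative sizes (assumptions, not read from the StackGAN config files).
BATCH_SIZE = 4
C_DIM = 128      # assumed size of the conditioning vector c
Z_DIM = 100      # assumed value of cfg.Z_DIM
LR_IMSIZE = 64   # assumed cfg.TEST.LR_IMSIZE
HR_IMSIZE = 256  # assumed cfg.TEST.HR_IMSIZE


def stage1_generator_input(c, z):
    """Concatenate conditioning vector and noise along the feature axis,
    mirroring tf.concat(1, [c, z]) in build_model."""
    return np.concatenate([c, z], axis=1)


rng = np.random.default_rng(0)
c = rng.standard_normal((BATCH_SIZE, C_DIM))
z = rng.standard_normal((BATCH_SIZE, Z_DIM))  # like tf.random_normal([batch_size, Z_DIM])
gen_input = stage1_generator_input(c, z)

# Upsampling factor passed to CondGAN as hr_lr_ratio.
hr_lr_ratio = int(HR_IMSIZE / LR_IMSIZE)

print(gen_input.shape)  # (4, 228)
print(hr_lr_ratio)      # 4
```

z is created inside the graph by tf.random_normal, so each sess.run of fake_images draws fresh noise, even when the caption embedding is the same.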

Callers 1

Calls 4

get_generator · Method · 0.95
hr_get_generator · Method · 0.95
CondGAN · Class · 0.90
sample_encoded_context · Function · 0.70
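build_model calls sample_encoded_context once per stage, each time under its own variable scope (g_net, then hr_g_net), so any variables the helper creates are separate for each stage. This page does not show the helper's body. The sketch below assumes it follows the conditioning-augmentation idea described for StackGAN. Under that assumption, the text embedding is projected to a mean and a log standard deviation, and the helper samples c = mean + exp(log_sigma) * eps with eps ~ N(0, I). The helper name sample_conditioning, the linear projection (w, b) and all sizes are hypothetical, including the 4800-dimensional embedding chosen for a skip-thought demo.

```python
import numpy as np


def sample_conditioning(embeddings, w, b, rng):
    """Hypothetical conditioning-augmentation step (reparameterization trick).

    embeddings: [batch_size, embedding_dim] text embeddings
    w, b:       hypothetical linear layer producing 2 * c_dim outputs
    Returns c of shape [batch_size, c_dim].
    """
    c_dim = w.shape[1] // 2
    projected = embeddings @ w + b
    mean = projected[:, :c_dim]          # first half: mean
    log_sigma = projected[:, c_dim:]     # second half: log standard deviation
    eps = rng.standard_normal(mean.shape)  # fresh noise on every call
    return mean + np.exp(log_sigma) * eps


rng = np.random.default_rng(0)
batch_size, embedding_dim, c_dim = 2, 4800, 128  # sizes are assumptions
embeddings = rng.standard_normal((batch_size, embedding_dim))
w = rng.standard_normal((embedding_dim, 2 * c_dim)) * 0.01
b = np.zeros(2 * c_dim)

c = sample_conditioning(embeddings, w, b, rng)
print(c.shape)  # (2, 128)
```

In this sketch, eps is redrawn on every call, so the same caption embedding yields a different c each time. That adds variation on top of the noise z.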

Tested by

no test coverage detected
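The checkpoint branch in build_model has no TensorFlow dependency. It checks whether the path string contains ".ckpt" anywhere and only prints a message otherwise, so it can be tested in isolation. The sketch below moves that condition, copied verbatim, into a hypothetical helper named looks_like_checkpoint and lists which inputs pass. The helper name and the example paths are illustrative and are not part of the repository.

```python
def looks_like_checkpoint(ckt_path):
    """Hypothetical helper: the same test build_model applies to
    cfg.TEST.PRETRAINED_MODEL before restoring variables."""
    return ckt_path.find('.ckpt') != -1


# Example paths are made up for illustration.
cases = {
    "models/birds_model.ckpt": True,
    "models/birds_model_164000.ckpt": True,
    "models/": False,
    "": False,
    # Loose substring match: passes merely because it contains ".ckpt".
    "notes.ckpt.txt": True,
}

for path, expected in cases.items():
    assert looks_like_checkpoint(path) == expected, path
print("all cases pass")
```

When this check fails, build_model still returns the graph tensors, and their variables were never restored. One option for making that failure visible earlier is to raise an exception in the else branch instead of printing.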